# 天天看点

# 蜡烛图、美国线绘图及标记

# 来源：https://www.joinquant.com/view/community/detail/92d2ccab2d412dbfa7df366369e6373b?type=1

import numpy as np
import pandas as pd
import empyrical as ep
import statsmodels.api as sm

import alphalens as al
from alphalens import plotting
import alphalens.performance as perf

from jqdata import *
from sqlalchemy.sql import func
from jqfactor import (Factor, calc_factors, neutralize,
                      standardlize, get_factor_values)

from functools import reduce
from tqdm import tqdm_notebook
from typing import (Tuple, List)
from dateutil.parser import parse

import seaborn as sns
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Rectangle


# Global plot settings.
# Serif font family for plot text; the original note says this is so pandas
# plots can show Chinese. NOTE(review): presumably the environment maps
# 'serif' to a CJK-capable font — confirm.
mpl.rcParams['font.family'] = 'serif'
# Draw negative numbers with an ASCII hyphen so the minus sign displays correctly.
mpl.rcParams['axes.unicode_minus'] = False
# Chart theme. Matplotlib 3.6 renamed the bundled 'seaborn' style to
# 'seaborn-v0_8' and 3.8 removed the old name, so use whichever is available.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available
              else 'seaborn')
           
# 蜡烛图

import matplotlib.dates as mdate
import matplotlib.ticker as ticker
from matplotlib.path import Path
from matplotlib.patches import PathPatch


# 画出蜡烛图
def plot_candlestick(df: pd.DataFrame, title: str = '', **kwargs):
    '''
    Draw a candlestick (K-line) chart of *df* on a new figure.

    Parameters
    ----------
    df : pd.DataFrame
        Price data whose index is a DatetimeIndex and which has 'open',
        'high', 'low' and 'close' columns. It is copied, not modified.
    title : str
        Figure title.
    **kwargs
        pathpatch : matplotlib.patches.PathPatch, optional
            Patch added on top of the chart to highlight bars, e.g. the
            result of ``get_mark_data(df, [...])``.
        figsize : tuple, optional
            Figure size in inches; default (18, 4).
        tick_interval : int, optional
            Put a date label every *tick_interval* bars; default 6.

    Notes
    -----
    Bar i is drawn at x = i (its row position), not at its date, so days
    without data leave no gaps; the x tick labels map positions back to
    'YYYY-MM-DD' strings. Bars are red when close >= open, green otherwise.
    '''
    df = df.copy()
    df.index.names = ['date']
    df = df.reset_index()

    # One 'YYYY-MM-DD' label per bar position.
    date_tickers = df['date'].dt.strftime('%Y-%m-%d').values

    # (x, open, high, low, close) tuples, x being the bar's row position.
    ohlc = df[['open', 'high', 'low', 'close']]
    day_quotes = [(i,) + tuple(row)
                  for i, row in enumerate(ohlc.itertuples(index=False))]

    mpl.rcParams['font.family'] = 'serif'
    fig, ax = plt.subplots(figsize=kwargs.get('figsize', (18, 4)))
    plt.title(title)

    def format_date(x, pos=None):
        # Positions outside the data range (axis padding) get no label.
        if x < 0 or x > len(date_tickers) - 1:
            return ''
        return date_tickers[int(x)]

    candlestick_ohlc(ax, day_quotes, colordown='g', colorup='r', width=0.2)

    if 'pathpatch' in kwargs:
        ax.add_patch(kwargs['pathpatch'])

    ax.xaxis.set_major_locator(
        ticker.MultipleLocator(kwargs.get('tick_interval', 6)))
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_date))
    ax.grid(True)
    
# 标记需要标记的K线
def get_mark_data(price: pd.DataFrame, target_date: list):
    '''
    Build a PathPatch that outlines the bars on the given dates.

    The patch is meant to be passed as ``pathpatch=`` to plot_candlestick /
    plot_HLC_bar, which draw bar i at x = i (its row position in *price*).

    Parameters
    ----------
    price : pd.DataFrame
        Price data whose index is a DatetimeIndex and which has 'open',
        'high', 'low' and 'close' columns.
    target_date : list or str
        Dates to mark as 'yyyy-mm-dd' strings. A single date may also be
        passed on its own. Non-string date-like values (datetime,
        pd.Timestamp, ...) are formatted to 'yyyy-mm-dd' before lookup.

    Returns
    -------
    PathPatch
        Unfilled black rectangles, each spanning [i - 0.5, i + 0.5]
        horizontally and [low * 0.999, high * 1.001] vertically.

    Raises
    ------
    KeyError
        If any requested date is not in *price*'s index.
    '''
    df = price[['open', 'high', 'low', 'close']].copy()
    df.index = df.index.strftime('%Y-%m-%d')

    # Accept a single date as well as a collection of dates. (The old check
    # was inverted and wrote to a misspelled name, so a bare string was
    # iterated character by character.)
    if isinstance(target_date, str) or not hasattr(target_date, '__iter__'):
        target_date = [target_date]

    # Strings are looked up verbatim; date-like values are formatted to
    # match the index labels built above.
    keys = [d if isinstance(d, str) else pd.Timestamp(d).strftime('%Y-%m-%d')
            for d in target_date]

    missing = [k for k in keys if k not in df.index]
    if missing:
        raise KeyError('dates not found in price index: {}'.format(missing))

    vertices = []
    codes = []
    for key in keys:
        i = df.index.get_loc(key)
        low = df['low'].iloc[i] * (1 - 0.001)
        high = df['high'].iloc[i] * (1 + 0.001)
        # Closed rectangle: left-bottom, left-top, right-top, right-bottom, close.
        codes += [Path.MOVETO] + [Path.LINETO] * 3 + [Path.CLOSEPOLY]
        vertices += [(i - 0.5, low), (i - 0.5, high), (i + 0.5, high),
                     (i + 0.5, low), (i - 0.5, low)]

    # Path requires an (N, 2) vertex array; reshaping keeps an empty
    # selection valid instead of raising.
    path = Path(np.asarray(vertices, dtype=float).reshape(-1, 2),
                codes or None)
    return PathPatch(path, facecolor='None', edgecolor='black', lw=2)

    
def candlestick_ohlc(ax, quotes, width=0.2, colorup='k', colordown='r',
                     alpha=1.0):
    """
    Draw candlesticks for quotes laid out as (x, open, high, low, close, ...).

    Delegates to ``_candlestick`` with ``ochl=False``. For each quote, a thin
    vertical line spans low..high and a filled rectangle spans the
    open..close range. Both use *colorup* when close >= open and
    *colordown* otherwise.

    Parameters
    ----------
    ax : `Axes`
        Axes the lines and rectangles are added to.
    quotes : sequence of (x, open, high, low, close, ...) sequences
        Only the first five items of each quote are read, so extra trailing
        fields (e.g. volume) are allowed.
    width : float
        Width of each rectangle in x-axis units.
    colorup : color
        Colour of bars where close >= open.
    colordown : color
        Colour of bars where close < open.
    alpha : float
        Alpha applied to the rectangles.

    Returns
    -------
    tuple
        ``(lines, patches)``: the vertical lines and the rectangle patches
        that were added to *ax*.
    """
    options = {'width': width, 'colorup': colorup, 'colordown': colordown,
               'alpha': alpha, 'ochl': False}
    return _candlestick(ax, quotes, **options)


def _candlestick(ax, quotes, width=0.2, colorup='k', colordown='r',
                 alpha=1.0, ochl=True):
    """
    Draw one candlestick per quote on *ax*.

    Each quote becomes a vertical line from low to high plus a filled
    rectangle covering the open-close span. Both use *colorup* when
    close >= open and *colordown* otherwise.

    Parameters
    ----------
    ax : `Axes`
        Axes the lines and rectangles are added to.
    quotes : sequence of quote sequences
        Only the first five items of each quote are read. The first item is
        the x-coordinate of the bar (callers in this file pass the bar's row
        position). The remaining four are ordered according to *ochl*.
    width : float
        Width of each rectangle in x-axis units; it is centred on x.
    colorup : color
        Colour of bars where close >= open.
    colordown : color
        Colour of bars where close < open.
    alpha : float
        Alpha applied to the rectangles (not to the lines).
    ochl : bool
        True:  quote is (x, open, close, high, low, ...).
        False: quote is (x, open, high, low, close, ...).

    Returns
    -------
    tuple
        ``(lines, patches)``: the Line2D objects and the Rectangle patches
        that were added to *ax*.
    """
    offset = width / 2.0

    lines = []
    patches = []
    for q in quotes:
        # Trailing-underscore names avoid shadowing the builtin ``open``.
        if ochl:
            t, open_, close_, high, low = q[:5]
        else:
            t, open_, high, low, close_ = q[:5]

        if close_ >= open_:
            color = colorup
            lower = open_
            height = close_ - open_
        else:
            color = colordown
            lower = close_
            height = open_ - close_

        # High-low range.
        vline = Line2D(
            xdata=(t, t), ydata=(low, high),
            color=color,
            linewidth=0.5,
            antialiased=True,
        )

        # Open-close body, centred on x.
        rect = Rectangle(
            xy=(t - offset, lower),
            width=width,
            height=height,
            facecolor=color,
            edgecolor=color,
        )
        rect.set_alpha(alpha)

        lines.append(vline)
        patches.append(rect)
        ax.add_line(vline)
        ax.add_patch(rect)
    ax.autoscale_view()

    return lines, patches


def _check_input(opens, closes, highs, lows, miss=-1):
    """Checks that *opens*, *highs*, *lows* and *closes* have the same length.
    NOTE: this code assumes if any value open, high, low, close is
    missing (*-1*) they all are missing
    Parameters
    ----------
    ax : `Axes`
        an Axes instance to plot to
    opens : sequence
        sequence of opening values
    highs : sequence
        sequence of high values
    lows : sequence
        sequence of low values
    closes : sequence
        sequence of closing values
    miss : int
        identifier of the missing data
    Raises
    ------
    ValueError
        if the input sequences don't have the same length
    """

    def _missing(sequence, miss=-1):
        """Returns the index in *sequence* of the missing data, identified by
        *miss*
        Parameters
        ----------
        sequence :
            sequence to evaluate
        miss :
            identifier of the missing data
        Returns
        -------
        where_miss: numpy.ndarray
            indices of the missing data
        """
        return np.where(np.array(sequence) == miss)[0]

    same_length = len(opens) == len(highs) == len(lows) == len(closes)
    _missopens = _missing(opens)
    same_missing = ((_missopens == _missing(highs)).all() and
                    (_missopens == _missing(lows)).all() and
                    (_missopens == _missing(closes)).all())

    if not (same_length and same_missing):
        msg = ("*opens*, *highs*, *lows* and *closes* must have the same"
               " length. NOTE: this code assumes if any value open, high,"
               " low, close is missing (*-1*) they all must be missing.")
        raise ValueError(msg)
           
# 画美国线
def plot_HLC_bar(df: pd.DataFrame, title: str = '', **kwargs):
    '''
    Draw an HLC bar chart ("US line" chart, 美国线) of *df* on a new figure.

    Each bar is a vertical line from low to high plus a short tick to its
    right at the close. Bars with close >= open are red and the others are
    green, the same rule plot_candlestick uses.

    Parameters
    ----------
    df : pd.DataFrame
        Price data whose index is a DatetimeIndex and which has 'open',
        'high', 'low' and 'close' columns ('open' only selects the colour).
    title : str
        Figure title.
    **kwargs
        pathpatch : matplotlib.patches.PathPatch, optional
            Patch added on top of the chart to highlight bars, e.g. the
            result of ``get_mark_data(df, [...])``.
        figsize : tuple, optional
            Figure size in inches; default (18, 4).
        tick_interval : int, optional
            Put a date label every *tick_interval* bars; default 6.
    '''
    # Line segments [(x0, y0), (x1, y1)], grouped by bar direction.
    up_bars, up_ticks = [], []
    down_bars, down_ticks = [], []

    # Bar i sits at x = i (row position), as in plot_candlestick.
    rows = zip(df['open'], df['high'], df['low'], df['close'])
    for x, (open_, high, low, close) in enumerate(rows):
        if close >= open_:
            bars, ticks = up_bars, up_ticks
        else:
            bars, ticks = down_bars, down_ticks
        bars.append([(x, low), (x, high)])            # high-low range
        ticks.append([(x, close), (x + 0.2, close)])  # close tick

    def _segments_patch(segments, color, lw):
        # One MOVETO/LINETO pair per segment, drawn as unfilled strokes.
        path = Path(np.asarray(segments, dtype=float).reshape(-1, 2),
                    [Path.MOVETO, Path.LINETO] * len(segments))
        return PathPatch(path, lw=lw, edgecolor=color, fill=False)

    fig, ax = plt.subplots(figsize=kwargs.get('figsize', (18, 4)))
    plt.title(title)

    # Path() rejects an empty vertex list, so skip a direction with no bars
    # (e.g. a window in which every day closed up).
    for segments, color, lw in ((up_bars, 'red', 1.5), (up_ticks, 'red', 0.6),
                                (down_bars, 'g', 1.5), (down_ticks, 'g', 0.6)):
        if segments:
            ax.add_patch(_segments_patch(segments, color, lw))

    ax.set_xlim(-0.5, len(df))
    ax.set_ylim(df['low'].min() * (1 - 0.01), df['high'].max() * (1 + 0.01))

    if 'pathpatch' in kwargs:
        ax.add_patch(kwargs['pathpatch'])

    # One 'YYYY-MM-DD' label per bar position.
    date_tickers = df.index.strftime('%Y-%m-%d').values

    def format_date(x, pos=None):
        # Positions outside the data range (axis padding) get no label.
        if x < 0 or x > len(date_tickers) - 1:
            return ''
        return date_tickers[int(x)]

    ax.xaxis.set_major_locator(
        ticker.MultipleLocator(kwargs.get('tick_interval', 6)))
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_date))
    ax.grid(True)
           
price = get_price('000001.XSHG', '2019-10-10', '2019-12-25')

# Outline three example bars, then draw the candlestick chart with them.
marked_bars = get_mark_data(price, ['2019-10-14', '2019-11-05', '2019-12-17'])
plot_candlestick(price, '上证指数蜡烛图(2019/10/10-2019/12/25)', pathpatch=marked_bars)
           
# 蜡烛图、美国线绘图及标记
# The marked dates are in Jan/Feb 2020, outside `price` (2019-10-10..2019-12-25),
# so get_mark_data would raise KeyError; fetch the window named in the title.
price_2020 = get_price('000001.XSHG', '2020-01-16', '2020-02-24')
plot_HLC_bar(price_2020, '上证指数蜡烛图举例(2020/01/16-2020/02/24)',
             pathpatch=get_mark_data(price_2020, ['2020-01-23', '2020-02-04']))
           
# 蜡烛图、美国线绘图及标记