# Source: https://www.joinquant.com/view/community/detail/92d2ccab2d412dbfa7df366369e6373b?type=1
import numpy as np
import pandas as pd
import empyrical as ep
import statsmodels.api as sm
import alphalens as al
from alphalens import plotting
import alphalens.performance as perf
from jqdata import *
from sqlalchemy.sql import func
from jqfactor import (Factor, calc_factors, neutralize,
standardlize, get_factor_values)
from functools import reduce
from tqdm import tqdm_notebook
from typing import (Tuple, List)
from dateutil.parser import parse
import seaborn as sns
import matplotlib.pyplot as plt
# Use a serif font family for plots (original note: needed for Chinese text
# in pd.plot). `plt.rcParams` is the same object as `matplotlib.rcParams`; it
# is used because this file imports `matplotlib.pyplot as plt` but never
# binds the name `mpl`.
plt.rcParams['font.family'] = 'serif'
# Render the minus sign correctly on the axes.
plt.rcParams['axes.unicode_minus'] = False
# Chart theme. Matplotlib 3.6 renamed the bundled seaborn styles to
# 'seaborn-v0_8*', so fall back to the new name when the old one is absent.
# NOTE(review): a style sheet may override rcParams set above it (e.g. the
# font family) -- confirm the intended order.
plt.style.use('seaborn' if 'seaborn' in plt.style.available else 'seaborn-v0_8')
# 蠟燭圖
import matplotlib.dates as mdate
import matplotlib.ticker as ticker
from matplotlib.path import Path
from matplotlib.patches import PathPatch
# Draw a candlestick chart
def plot_candlestick(df:pd.DataFrame,title:str='',**kwargs):
    '''
    Draw a candlestick chart on a new 18x4 figure.

    Bars are placed at consecutive integer x positions (0, 1, 2, ...) rather
    than at real dates, so there are no gaps between rows; the x tick labels
    are mapped back to the date of the row at that position.

    Parameters
    ----------
    df : pd.DataFrame
        Price data with columns 'open', 'high', 'low', 'close'. The index
        must be datetime-like (it is formatted with ``.dt.strftime``).
        The frame is copied and not modified.
    title : str
        Title of the chart.
    kwargs :
        pathpatch : optional matplotlib patch added on top of the candles,
        e.g. the outline returned by ``get_mark_data``.
    '''
    df = df.copy()
    df.index.names = ['date']
    df = df.reset_index()
    data = df[['date','open','high','low','close']]

    # Tick labels, one per row, looked up by integer x position.
    date_tickers = df['date'].dt.strftime('%Y-%m-%d').values

    # Replace each row's date with its row number: (i, open, high, low, close).
    day_quotes = [(i, *quote[1:]) for i, quote in enumerate(data.values)]

    # `plt.rcParams` is the same object as `matplotlib.rcParams`; the name
    # `mpl` is not bound by this file's imports.
    plt.rcParams['font.family'] = 'serif'
    fig, ax = plt.subplots(figsize=(18,4))
    plt.title(title)

    def format_date(x,pos=None):
        # Positions outside the data range get an empty label.
        if x<0 or x>len(date_tickers)-1:
            return ''
        return date_tickers[int(x)]

    # Red for rising bars, green for falling bars.
    candlestick_ohlc(ax,day_quotes,colordown='g', colorup='r',width=0.2)

    if 'pathpatch' in kwargs:
        ax.add_patch(kwargs['pathpatch'])

    # One major tick every 6 bars.
    ax.xaxis.set_major_locator(ticker.MultipleLocator(6))
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_date))
    ax.grid(True)
# Outline the bars that should be highlighted
def get_mark_data(price:pd.DataFrame,target_date:list):
    '''
    Build a patch that draws a rectangle around the bars at the given dates.

    Each rectangle is one x unit wide, centred on the row position of the
    date, and reaches from 0.1% below the bar's low to 0.1% above its high.
    The x coordinates are row positions within ``price``, so the patch lines
    up with charts that plot bars at integer positions.

    Parameters
    ----------
    price : pd.DataFrame
        Price data with columns 'open', 'high', 'low', 'close' and a
        datetime-like index (it is formatted with ``strftime``).
    target_date : list of str, or str
        Date(s) to mark, formatted 'yyyy-mm-dd'. A single string is treated
        as a one-element list.

    Returns
    -------
    pathpatch : matplotlib.patches.PathPatch
        Unfilled black outline, line width 2.

    Raises
    ------
    KeyError
        If a date is not present in the index of ``price``.
    '''
    df = price[['open','high','low','close']].copy()
    df.index = df.index.strftime('%Y-%m-%d')

    # A lone date string must be wrapped, otherwise the loop below would
    # iterate over its characters.
    if isinstance(target_date, str):
        target_date = [target_date]

    margin = 0.001      # relative padding below the low / above the high
    half_width = 0.5    # the rectangle spans one full bar slot

    vertices = []
    codes = []
    for day in target_date:
        i = df.index.get_loc(day)
        low = df['low'].iloc[i] * (1 - margin)
        high = df['high'].iloc[i] * (1 + margin)
        left, right = i - half_width, i + half_width

        # One closed rectangle per marked bar.
        codes += [Path.MOVETO] + [Path.LINETO]*3 + [Path.CLOSEPOLY]
        vertices += [(left, low), (left, high), (right, high), (right, low), (left, low)]

    path = Path(vertices, codes)
    pathpatch = PathPatch(path, facecolor='None', edgecolor='black',lw=2)
    return pathpatch
def candlestick_ohlc(ax, quotes, width=0.2, colorup='k', colordown='r',
                     alpha=1.0):
    """
    Draw candlesticks from quotes ordered (time, open, high, low, close).

    Each quote becomes a vertical line from low to high plus a rectangle
    covering the open-close span. The rectangle uses `colorup` when
    close >= open and `colordown` otherwise. This is a thin wrapper that
    forwards to `_candlestick` with ``ochl=False``.

    Parameters
    ----------
    ax : `Axes`
        Axes instance to draw on
    quotes : sequence of (time, open, high, low, close, ...) sequences
        Only the first 5 elements of each record are read; extra elements
        (e.g. volume) are ignored. `time` is used as the x coordinate.
    width : float
        width of the rectangle, in x-axis units
    colorup : color
        rectangle colour when close >= open
    colordown : color
        rectangle colour when close < open
    alpha : float
        alpha level of the rectangle

    Returns
    -------
    ret : tuple
        (lines, patches): the list of lines added and the list of
        rectangle patches added
    """
    appearance = {
        'width': width,
        'colorup': colorup,
        'colordown': colordown,
        'alpha': alpha,
    }
    return _candlestick(ax, quotes, ochl=False, **appearance)
def _candlestick(ax, quotes, width=0.2, colorup='k', colordown='r',
                 alpha=1.0, ochl=True):
    """
    Draw one candlestick per quote on `ax`.

    Each quote becomes a vertical line from low to high plus a rectangle
    covering the open-close span. The line and rectangle use `colorup` when
    close >= open and `colordown` otherwise.

    Parameters
    ----------
    ax : `Axes`
        Axes instance to draw on
    quotes : sequence of quote sequences
        Only the first 5 elements of each quote are read. Their order is
        (time, open, close, high, low, ...) when `ochl` is true and
        (time, open, high, low, close, ...) otherwise. `time` is used as
        the x coordinate of the bar.
    width : float
        width of the rectangle, in x-axis units
    colorup : color
        colour used when close >= open
    colordown : color
        colour used when close < open
    alpha : float
        alpha level of the rectangle
    ochl : bool
        selects between ochl and ohlc ordering of quotes

    Returns
    -------
    ret : tuple
        (lines, patches): the list of lines added and the list of
        rectangle patches added
    """
    # Imported locally: the module-level imports visible in this file bring
    # in Path and PathPatch only, so these two names would be unresolved.
    from matplotlib.lines import Line2D
    from matplotlib.patches import Rectangle

    half_width = width / 2.0

    lines = []
    patches = []
    for q in quotes:
        # `open_` rather than `open` so the builtin is not shadowed.
        if ochl:
            t, open_, close, high, low = q[:5]
        else:
            t, open_, high, low, close = q[:5]

        # The rectangle always starts at the smaller of open/close.
        if close >= open_:
            color = colorup
            lower = open_
            height = close - open_
        else:
            color = colordown
            lower = close
            height = open_ - close

        vline = Line2D(
            xdata=(t, t), ydata=(low, high),
            color=color,
            linewidth=0.5,
            antialiased=True,
        )
        rect = Rectangle(
            xy=(t - half_width, lower),
            width=width,
            height=height,
            facecolor=color,
            edgecolor=color,
        )
        rect.set_alpha(alpha)

        lines.append(vline)
        patches.append(rect)
        ax.add_line(vline)
        ax.add_patch(rect)

    # Rescale once after all artists have been added.
    ax.autoscale_view()
    return lines, patches
def _check_input(opens, closes, highs, lows, miss=-1):
"""Checks that *opens*, *highs*, *lows* and *closes* have the same length.
NOTE: this code assumes if any value open, high, low, close is
missing (*-1*) they all are missing
Parameters
----------
ax : `Axes`
an Axes instance to plot to
opens : sequence
sequence of opening values
highs : sequence
sequence of high values
lows : sequence
sequence of low values
closes : sequence
sequence of closing values
miss : int
identifier of the missing data
Raises
------
ValueError
if the input sequences don't have the same length
"""
def _missing(sequence, miss=-1):
"""Returns the index in *sequence* of the missing data, identified by
*miss*
Parameters
----------
sequence :
sequence to evaluate
miss :
identifier of the missing data
Returns
-------
where_miss: numpy.ndarray
indices of the missing data
"""
return np.where(np.array(sequence) == miss)[0]
same_length = len(opens) == len(highs) == len(lows) == len(closes)
_missopens = _missing(opens)
same_missing = ((_missopens == _missing(highs)).all() and
(_missopens == _missing(lows)).all() and
(_missopens == _missing(closes)).all())
if not (same_length and same_missing):
msg = ("*opens*, *highs*, *lows* and *closes* must have the same"
" length. NOTE: this code assumes if any value open, high,"
" low, close is missing (*-1*) they all must be missing.")
raise ValueError(msg)
# Draw an American-style (HLC) bar chart
def plot_HLC_bar(df:pd.DataFrame,title:str='',**kwargs):
    '''
    Draw an HLC bar chart on a new 18x4 figure.

    Every row becomes a vertical stem from low to high with a short
    horizontal tick to the right at the close. Rows with open < close are
    drawn in red, all other rows in green. Bars are placed at consecutive
    integer x positions; tick labels show the date of the row at that
    position.

    Parameters
    ----------
    df : pd.DataFrame
        Price data with columns 'open', 'high', 'low', 'close' ('open' is
        only used to pick the colour). The index must be datetime-like
        (it is formatted with ``strftime``).
    title : str
        Title of the chart.
    kwargs :
        pathpatch : optional matplotlib patch added on top of the bars,
        e.g. the outline returned by ``get_mark_data``.
    '''
    # Vertex lists per direction. Every consecutive pair of vertices is one
    # independent line segment.
    up_stems, up_ticks = [], []
    down_stems, down_ticks = [], []

    hlc = df.reset_index(drop=True)
    for idx,row in hlc.iterrows():
        low = row['low']
        high = row['high']
        close = row['close']
        open_ = row['open']

        if open_ < close:
            stems, ticks = up_stems, up_ticks        # rising bar
        else:
            stems, ticks = down_stems, down_ticks    # falling or flat bar
        stems += [(idx, low), (idx, high)]
        ticks += [(idx, close), (idx+0.2, close)]

    def _segments_patch(vertices, lw, color):
        # One MOVETO/LINETO pair per segment.
        codes = [Path.MOVETO, Path.LINETO] * (len(vertices) // 2)
        return PathPatch(Path(vertices, codes), lw=lw, edgecolor=color)

    fig, ax = plt.subplots(figsize=(18,4))
    plt.title(title)

    groups = (
        (up_stems, 1.5, 'red'),
        (up_ticks, 0.6, 'red'),
        (down_stems, 1.5, 'g'),
        (down_ticks, 0.6, 'g'),
    )
    for vertices, lw, color in groups:
        # Path rejects an empty vertex list, which happens when every bar
        # is rising or every bar is falling; skip the empty group.
        if vertices:
            ax.add_patch(_segments_patch(vertices, lw, color))

    # Patches do not autoscale the view, so set the limits explicitly.
    ax.set_xlim(-0.5, len(df))
    ax.set_ylim(df['low'].min() * (1 - 0.01), df['high'].max() * (1 + 0.01))

    if 'pathpatch' in kwargs:
        ax.add_patch(kwargs['pathpatch'])

    # Tick labels, one per row, looked up by integer x position.
    date_tickers = df.index.strftime('%Y-%m-%d').values

    def format_date(x,pos=None):
        # Positions outside the data range get an empty label.
        if x<0 or x>len(date_tickers)-1:
            return ''
        return date_tickers[int(x)]

    # One major tick every 6 bars.
    ax.xaxis.set_major_locator(ticker.MultipleLocator(6))
    ax.xaxis.set_major_formatter(ticker.FuncFormatter(format_date))
    ax.grid(True)
# Example 1: candlestick chart of the Shanghai Composite index with three
# bars outlined.
price = get_price('000001.XSHG','2019-10-10','2019-12-25')
plot_candlestick(price,'上證指數蠟燭圖(2019/10/10-2019/12/25)',
                 pathpatch=get_mark_data(price,['2019-10-14','2019-11-05','2019-12-17']))

# Example 2: HLC bar chart for the period named in its title. It needs its
# own data: the marked 2020 dates are not in the 2019 frame above, so
# looking them up there raises KeyError. A separate name keeps `price`
# unchanged.
price_2020 = get_price('000001.XSHG','2020-01-16','2020-02-24')
plot_HLC_bar(price_2020,'上證指數蠟燭圖舉例(2020/01/16-2020/02/24)',
             pathpatch=get_mark_data(price_2020,['2020-01-23','2020-02-04']))