得到XGBoost训练出的模型中的决策树的清晰PDF图像
直接调用
生成的图像是png格式,过于模糊,想要得到矢量的PDF,以及适合插入到markdown(Typora)中的SVG格式,有一种简单的方式是修改XBGBoost源码,需要修改的文件是
C:\Users\Username\AppData\Local\Programs\Python\Python38\Lib\site-packages\xgboost\plotting.py
,将最后一个地方的函数添加下面两行
g.render('XGBoost_tree'+str(num_trees), format='pdf', cleanup=True) # 2020-12-19
g.render('XGBoost_tree'+str(num_trees), format='svg', cleanup=True) # 2020-12-19
改成下面这样
def plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs):
"""Plot specified tree.
Parameters
----------
booster : Booster, XGBModel
Booster or XGBModel instance
fmap: str (optional)
The name of feature map file
num_trees : int, default 0
Specify the ordinal number of target tree
rankdir : str, default "TB"
Passed to graphiz via graph_attr
ax : matplotlib Axes, default None
Target axes instance. If None, new figure and axes will be created.
kwargs :
Other keywords passed to to_graphviz
Returns
-------
ax : matplotlib Axes
"""
try:
from matplotlib import pyplot as plt
from matplotlib import image
except ImportError as e:
raise ImportError('You must install matplotlib to plot tree') from e
if ax is None:
_, ax = plt.subplots(1, 1)
g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir,
**kwargs)
g.render('XGBoost_tree'+str(num_trees),
format='pdf', cleanup=True) # 2020-12-19
g.render('XGBoost_tree'+str(num_trees),
format='svg', cleanup=True) # 2020-12-19
s = BytesIO()
s.write(g.pipe(format='png'))
s.seek(0)
img = image.imread(s)
ax.imshow(img)
ax.axis('off')
return ax
这样就会在调用
的目录下生成相应的PDF文件和SVG文件。
以上です。