天天看点

得到XGBoost训练出的模型中的决策树的清晰PDF图像得到XGBoost训练出的模型中的决策树的清晰PDF图像

得到XGBoost训练出的模型中的决策树的清晰PDF图像

直接调用

生成的图像是png格式,过于模糊,想要得到矢量的PDF,以及适合插入到markdown(Typora)中的SVG格式,有一种简单的方式是修改XBGBoost源码,需要修改的文件是

C:\Users\Username\AppData\Local\Programs\Python\Python38\Lib\site-packages\xgboost\plotting.py

,将最后一个地方的函数添加下面两行

g.render('XGBoost_tree'+str(num_trees), format='pdf', cleanup=True)  # 2020-12-19
g.render('XGBoost_tree'+str(num_trees), format='svg', cleanup=True)  # 2020-12-19
           

改成下面这样

def plot_tree(booster, fmap='', num_trees=0, rankdir=None, ax=None, **kwargs):
    """Plot specified tree.

    Parameters
    ----------
    booster : Booster, XGBModel
        Booster or XGBModel instance
    fmap: str (optional)
       The name of feature map file
    num_trees : int, default 0
        Specify the ordinal number of target tree
    rankdir : str, default "TB"
        Passed to graphiz via graph_attr
    ax : matplotlib Axes, default None
        Target axes instance. If None, new figure and axes will be created.
    kwargs :
        Other keywords passed to to_graphviz

    Returns
    -------
    ax : matplotlib Axes

    """
    try:
        from matplotlib import pyplot as plt
        from matplotlib import image
    except ImportError as e:
        raise ImportError('You must install matplotlib to plot tree') from e

    if ax is None:
        _, ax = plt.subplots(1, 1)

    g = to_graphviz(booster, fmap=fmap, num_trees=num_trees, rankdir=rankdir,
                    **kwargs)
    g.render('XGBoost_tree'+str(num_trees),
         format='pdf', cleanup=True)  # 2020-12-19
    g.render('XGBoost_tree'+str(num_trees),
         format='svg', cleanup=True)  # 2020-12-19
    s = BytesIO()
    s.write(g.pipe(format='png'))
    s.seek(0)
    img = image.imread(s)

    ax.imshow(img)
    ax.axis('off')
    return ax
           

这样就会在调用

的目录下生成相应的PDF文件和SVG文件。

以上です。