天天看点

决策树分裂可视化

# -*- coding: utf-8 -*-
"""
Created on Mon Aug 16 10:43:40 2021

@author: 1
"""
import dtreeviz
import pandas as pd
import numpy as np
from sklearn.datasets import *
from sklearn import tree
'''
iris = load_iris()
df_iris = pd.DataFrame(iris['data'],columns = iris['feature_names'])
df_iris['target'] = iris['target']
'''
# Load the water-quality table; the first CSV column becomes the row index.
data = pd.read_csv('water_data.csv', index_col=0)

# Force the measurement columns to a numeric dtype. Values that cannot be
# parsed become NaN (errors='coerce') instead of raising. Applying
# pd.to_numeric column-wise converts each column in one vectorized call.
numeric_columns = ['TN', 'TEMP', 'COND', 'TURB']
data[numeric_columns] = data[numeric_columns].apply(pd.to_numeric, errors='coerce')

# Feature matrix: every column except the class label ('lable' is the
# spelling used by the CSV header, so it must be kept as-is).
data_feature = data.drop(['lable'], axis=1)

# Take the names straight from the feature matrix so they always line up with
# its columns, wherever the label column happens to sit in the file (the
# previous hard-coded position only worked when 'lable' was the 10th column).
feature_names = list(data_feature.columns)
target_names = ['0', '1', '2', '3', '4', '5']

# Shift the labels down by one.
# NOTE(review): presumably the raw labels are 1..6, which yields 0..5 and
# matches target_names -- confirm against the data.
data['lable'] = data['lable'] - 1



# Build a decision-tree classifier with scikit-learn's default settings and
# train it on the feature matrix against the 'lable' column. fit() returns the
# estimator itself, so construction and training can be chained.
clf = tree.DecisionTreeClassifier().fit(data_feature, data['lable'])

import graphviz

# Options for the Graphviz export of the fitted tree.
_export_options = dict(
    out_file=None,  # return the DOT source as a string instead of writing a file
    feature_names=feature_names,
    class_names=target_names,
    filled=True,
    rounded=True,
    special_characters=True,
)
dot_data = tree.export_graphviz(clf, **_export_options)

# Wrap the DOT text so graphviz can render it.
graph = graphviz.Source(dot_data)
# Bare expression: shows the graph when it is the last statement of an
# interactive / notebook cell; it has no visible effect in a plain script run.
graph

from dtreeviz.trees import dtreeviz


def _draw_tree(**extra):
    """Render the fitted tree with dtreeviz using the arguments both views share.

    Any keyword arguments are forwarded unchanged to ``dtreeviz`` on top of
    the common ones (model, features, labels, feature and class names).
    """
    return dtreeviz(
        clf,
        data_feature,
        data['lable'],
        target_name='',
        feature_names=np.array(feature_names),
        class_names={label: str(label) for label in range(6)},
        **extra,
    )


# Overview of the whole tree, passed scale=3.
viz = _draw_tree(scale=3)

# Same tree, additionally given the row whose index label is 0 as X.
viz1 = _draw_tree(X=data_feature.loc[0])
           
决策树分裂可视化