
Machine Learning: Decision Trees in Practice, Predicting Contact Lens Type

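Steps:

Collect the data: use the small dataset provided with the book

Prepare the data: preprocess the text data, for example by parsing each data line

Analyze the data: quickly inspect the data, and use the createPlot() function to draw the final tree diagram

Train the decision tree: train with the createTree() function

Test the decision tree: write a simple test function to verify the tree's output and the plotted result

Use the decision tree: optionally store the trained tree so that it can be reused at any time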
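
The last step, storing a trained tree, can be sketched with Python's standard pickle module. This is only an illustration under stated assumptions: the nested-dictionary tree below is hand-written and hypothetical (it borrows feature and class names from the dataset but is not the tree learned from it), and store_tree / load_tree are helper names chosen for this example.

```python
import os
import pickle
import tempfile


def store_tree(tree_obj, filename):
    """Serialize a tree (any picklable object) to a file."""
    with open(filename, 'wb') as fw:
        pickle.dump(tree_obj, fw)


def load_tree(filename):
    """Load a previously stored tree back into memory."""
    with open(filename, 'rb') as fr:
        return pickle.load(fr)


# Hypothetical, hand-written tree used only to demonstrate storage;
# it is NOT the tree learned from the lenses data.
example_tree = {'tearRate': {'reduced': 'no lenses',
                             'normal': {'astigmatic': {'yes': 'hard',
                                                       'no': 'soft'}}}}

# Write to a temporary directory so the example leaves no files behind
path = os.path.join(tempfile.mkdtemp(), 'lenses_tree.pkl')
store_tree(example_tree, path)
restored_tree = load_tree(path)
print(restored_tree == example_tree)
```

A fitted scikit-learn classifier can be pickled in the same way, although a pickle should only be loaded with a compatible library version and from a trusted source.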

1. Dataset

young	myope	no	reduced	no lenses
young	myope	no	normal	soft
young	myope	yes	reduced	no lenses
young	myope	yes	normal	hard
young	hyper	no	reduced	no lenses
young	hyper	no	normal	soft
young	hyper	yes	reduced	no lenses
young	hyper	yes	normal	hard
pre	myope	no	reduced	no lenses
pre	myope	no	normal	soft
pre	myope	yes	reduced	no lenses
pre	myope	yes	normal	hard
pre	hyper	no	reduced	no lenses
pre	hyper	no	normal	soft
pre	hyper	yes	reduced	no lenses
pre	hyper	yes	normal	no lenses
presbyopic	myope	no	reduced	no lenses
presbyopic	myope	no	normal	no lenses
presbyopic	myope	yes	reduced	no lenses
presbyopic	myope	yes	normal	hard
presbyopic	hyper	no	reduced	no lenses
presbyopic	hyper	no	normal	soft
presbyopic	hyper	yes	reduced	no lenses
presbyopic	hyper	yes	normal	no lenses
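
Each record has four tab-separated feature values followed by the class label. As a rough sketch of the "prepare the data" step, the snippet below parses the first four records shown above and counts the classes; the column names are the ones used in the code later in this article, and the inline string stands in for reading lenses.txt.

```python
from collections import Counter

# First four records of the dataset above, tab-separated
raw = (
    "young\tmyope\tno\treduced\tno lenses\n"
    "young\tmyope\tno\tnormal\tsoft\n"
    "young\tmyope\tyes\treduced\tno lenses\n"
    "young\tmyope\tyes\tnormal\thard\n"
)

feature_names = ['age', 'prescript', 'astigmatic', 'tearRate']

# Split every non-empty line into its five fields
records = [line.split('\t') for line in raw.splitlines() if line.strip()]
features = [rec[:-1] for rec in records]  # first four fields are features
labels = [rec[-1] for rec in records]     # last field is the class label

class_counts = Counter(labels)
print(dict(zip(feature_names, features[0])))
print(class_counts)
```

Over the full 24 records the classes are unbalanced: 15 "no lenses", 5 "soft" and 4 "hard".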
           

2. The code is as follows (it uses scikit-learn and pydotplus)

# -*- coding: UTF-8 -*-
from io import StringIO  # standard-library StringIO; sklearn.externals.six is gone from recent scikit-learn

from sklearn.preprocessing import LabelEncoder
from sklearn import tree
import pandas as pd
import pydotplus

if __name__ == '__main__':
    # Load the file and split each non-empty, tab-separated line into fields
    with open('lenses.txt', 'r') as fr:
        lenses = [inst.strip().split('\t') for inst in fr if inst.strip()]
    # The last field of each record is its class label
    lenses_target = []
    for each in lenses:
        lenses_target.append(each[-1])
    # print(lenses_target)

    lensesLabels = ['age', 'prescript', 'astigmatic', 'tearRate']  # feature names
    lenses_list = []  # temporary list holding the values of one feature
    lenses_dict = {}  # feature name -> list of values, used to build the DataFrame
    for each_label in lensesLabels:
        for each in lenses:
            lenses_list.append(each[lensesLabels.index(each_label)])
        lenses_dict[each_label] = lenses_list
        lenses_list = []
    # print(lenses_dict)
    # Pass columns explicitly so the feature order is fixed
    lenses_pd = pd.DataFrame(lenses_dict, columns=lensesLabels)
    # print(lenses_pd)
    le = LabelEncoder()  # encodes each category string as an integer
    for col in lenses_pd.columns:
        lenses_pd[col] = le.fit_transform(lenses_pd[col])
    # print(lenses_pd)  # the encoded features

    # Create the classifier and fit it to the encoded data
    clf = tree.DecisionTreeClassifier(max_depth=4)
    clf = clf.fit(lenses_pd.values.tolist(), lenses_target)

    # Export the tree in Graphviz DOT format
    dot_data = StringIO()
    tree.export_graphviz(clf, out_file=dot_data,
                         feature_names=list(lenses_pd.columns),
                         class_names=list(clf.classes_),
                         filled=True, rounded=True,
                         special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data.getvalue())
    graph.write_pdf("tree.pdf")  # save the drawing as a PDF (requires Graphviz to be installed)

    # Predict the class of one label-encoded sample
    print(clf.predict([[1, 1, 1, 0]]))
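
To read the sample passed to clf.predict(), it helps to know that LabelEncoder numbers the distinct values of a column in sorted order. The sketch below reproduces that mapping in plain Python so the encoded sample can be decoded by hand; label_codes is a helper name made up for this example and is not part of scikit-learn.

```python
def label_codes(values):
    """Map each distinct string to its index in sorted order."""
    return {v: i for i, v in enumerate(sorted(set(values)))}


# Distinct values of each feature, as they appear in the dataset
columns = {
    'age': ['young', 'pre', 'presbyopic'],
    'prescript': ['myope', 'hyper'],
    'astigmatic': ['no', 'yes'],
    'tearRate': ['reduced', 'normal'],
}

codes = {name: label_codes(vals) for name, vals in columns.items()}

# Decode the sample used in the script: [1, 1, 1, 0]
sample = [1, 1, 1, 0]
decoded = []
for (name, mapping), code in zip(codes.items(), sample):
    inverse = {i: v for v, i in mapping.items()}  # integer code -> string
    decoded.append(inverse[code])
print(decoded)
```

The sample therefore stands for a presbyopic, myope, astigmatic patient with a normal tear rate, a record that the dataset labels "hard".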
           
