天天看點

決策樹-泰坦尼克号生還預測

LR和SVM都在某種程度上要求被學習的資料特征和目标之間遵照線性假設。然後許多現實場景下,這種假設不存在。 比如根據年齡預測流感的死亡率,如果用線性模型假設,那隻有兩個可能:年齡越大/越小,死亡率越高。根據經驗,青壯年更不容易因患流感而死亡。年齡和因流感的死亡不存線上性關系。 在機器學習模型中,決策樹是描述非線性關系的不二之選。 信用卡申請的稽核,涉及多項特征,是典型的決策樹模型。對于是否同意申請,是二分類決策任務,隻有yes/no兩種分類結果。 使用多種不同特征組合搭建多層決策樹的情況,模型在學習的時候需要考慮特征節點的選取順序。常用的方式包括資訊熵(Information Gain)和基尼不純性(Gini Impurity)。本文不做讨論。sklearn中預設配置的決策樹模型使用的是Gini impurity作為排序特征的度量名額。 雖然很難擷取信用卡的客戶資料,但有類似的借助客戶檔案進行二分類的任務。 本文進行泰坦尼克号的乘客的生還預測,許多專家嘗試通過計算機模拟和分析找出隐藏在資料背後的生還邏輯。

Python源碼:

#coding=utf-8
import pandas as pd
#-------------data split
from sklearn.cross_validation import train_test_split
#-------------feature transfer
from sklearn.feature_extraction import DictVectorizer
#-------------
from sklearn.tree import DecisionTreeClassifier
#-------------
from sklearn.metrics import classification_report

#-------------download data
titanic=pd.read_csv('http://biostat.mc.vanderbilt.edu/wiki/pub/Main/DataSets/titanic.txt')
print titanic.head()
#transfer to dataFrame format by pandas,use info() to show statistics of data
print titanic.info()
#-------------feature selection
X=titanic[['pclass','age','sex']]
y=titanic['survived']

print 'bf processing\n',X.info()
#-------------feature processing
X['age'].fillna(X['age'].mean(),inplace=True)
print 'af processing\n',X.info
#-------------data split
#75% training set,25% testing set
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.25,random_state=33)
#-------------feature transfer  from String to int
vec=DictVectorizer(sparse=False)
X_train=vec.fit_transform(X_train.to_dict(orient='record'))
#print vec.feature_names  60
#AttributeError: 'DictVectorizer' object has no attribute 'feature_names'
print vec.get_feature_names()
X_test=vec.transform(X_test.to_dict(orient='record'))
#-------------training
#initialize
dtc=DecisionTreeClassifier()
dtc.fit(X_train,y_train)
y_predict=dtc.predict(X_test)
#-------------performance
print 'The Accuracy is',dtc.score(X_test,y_test)
print classification_report(y_test,y_predict,target_names=['died','survived'])
           

Result:    row.names pclass  survived  \

0          1    1st         1

1          2    1st         0

2          3    1st         0

3          4    1st         0

4          5    1st         1

                                              name      age     embarked  \

0                     Allen, Miss Elisabeth Walton  29.0000  Southampton

1                      Allison, Miss Helen Loraine   2.0000  Southampton

2              Allison, Mr Hudson Joshua Creighton  30.0000  Southampton

3  Allison, Mrs Hudson J.C. (Bessie Waldo Daniels)  25.0000  Southampton

4                    Allison, Master Hudson Trevor   0.9167  Southampton

                         home.dest room      ticket   boat     sex

0                     St Louis, MO  B-5  24160 L221      2  female

1  Montreal, PQ / Chesterville, ON  C26         NaN    NaN  female

2  Montreal, PQ / Chesterville, ON  C26         NaN  (135)    male

3  Montreal, PQ / Chesterville, ON  C26         NaN    NaN  female

4  Montreal, PQ / Chesterville, ON  C22         NaN     11    male

<class 'pandas.core.frame.DataFrame'>

RangeIndex: 1313 entries, 0 to 1312

Data columns (total 11 columns):

row.names    1313 non-null int64

pclass       1313 non-null object

survived     1313 non-null int64

name         1313 non-null object

age          633 non-null float64

embarked     821 non-null object

home.dest    754 non-null object

room         77 non-null object

ticket       69 non-null object

boat         347 non-null object

sex          1313 non-null object

dtypes: float64(1), int64(2), object(8)

memory usage: 112.9+ KB

None

bf processing

<class 'pandas.core.frame.DataFrame'>

RangeIndex: 1313 entries, 0 to 1312

Data columns (total 3 columns):

pclass    1313 non-null object

age       633 non-null float64

sex       1313 non-null object

dtypes: float64(1), object(2)

memory usage: 30.8+ KB

None

/Users/mac/workspace/conda/anaconda/lib/python2.7/site-packages/pandas/core/generic.py:3660: SettingWithCopyWarning:

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: http://pandas.pydata.org/pandas-docs/stable/indexing.html#indexing-view-versus-copy

  self._update_inplace(new_data)

af processing

<bound method DataFrame.info of      pclass        age     sex

0       1st  29.000000  female

1       1st   2.000000  female

2       1st  30.000000    male

3       1st  25.000000  female

4       1st   0.916700    male

5       1st  47.000000    male

6       1st  63.000000  female

7       1st  39.000000    male

8       1st  58.000000  female

9       1st  71.000000    male

10      1st  47.000000    male

11      1st  19.000000  female

12      1st  31.194181  female

13      1st  31.194181    male

14      1st  31.194181    male

15      1st  50.000000  female

16      1st  24.000000    male

17      1st  36.000000    male

18      1st  37.000000    male

19      1st  47.000000  female

20      1st  26.000000    male

21      1st  25.000000    male

22      1st  25.000000    male

23      1st  19.000000  female

24      1st  28.000000    male

25      1st  45.000000    male

26      1st  39.000000    male

27      1st  30.000000  female

28      1st  58.000000  female

29      1st  31.194181    male

...     ...        ...     ...

1283    3rd  31.194181  female

1284    3rd  31.194181    male

1285    3rd  31.194181    male

1286    3rd  31.194181    male

1287    3rd  31.194181    male

1288    3rd  31.194181    male

1289    3rd  31.194181    male

1290    3rd  31.194181    male

1291    3rd  31.194181    male

1292    3rd  31.194181    male

1293    3rd  31.194181  female

1294    3rd  31.194181    male

1295    3rd  31.194181    male

1296    3rd  31.194181    male

1297    3rd  31.194181    male

1298    3rd  31.194181    male

1299    3rd  31.194181    male

1300    3rd  31.194181    male

1301    3rd  31.194181    male

1302    3rd  31.194181    male

1303    3rd  31.194181    male

1304    3rd  31.194181  female

1305    3rd  31.194181    male

1306    3rd  31.194181  female

1307    3rd  31.194181  female

1308    3rd  31.194181    male

1309    3rd  31.194181    male

1310    3rd  31.194181    male

1311    3rd  31.194181  female

1312    3rd  31.194181    male

[1313 rows x 3 columns]>

['age', 'pclass=1st', 'pclass=2nd', 'pclass=3rd', 'sex=female', 'sex=male']

The Accuracy is 0.781155015198

             precision    recall  f1-score   support

       died       0.78      0.91      0.84       202

   survived       0.80      0.58      0.67       127

avg / total       0.78      0.78      0.77       329

該資料共有1313條乘客資訊,有些特征資料是缺失的,有些是數值類型,有些是字元串。 預處理環節中特征的選擇十分重要,需要一些背景知識,根據對事故的了解,sex,age,pclass都可能是關鍵因素。 需要完成的資料處理任務: 1.初始的資料中,age列隻有633個需要補充完整,一般,使用平均數或者中位數都是對模型偏離造成最小影響的政策。 2.sex和pclass列的值是列别型,需轉化為數值特征,用0/1代替 算法特點: 相對于其它的模型,決策樹在模型描述上有巨大的優勢,推斷邏輯非常直覺,具有清晰的可解釋性,也友善了模型的可視化。這些特征同時也保證使用該模型時,無需考慮對資料的量化甚至标準化的。與KNN不同,DT仍然屬于有參數模型,需花費更多的時間在訓練資料上。

繼續閱讀