天天看點

TensorFlow的高階接口Estimator的使用(1)

在《TensorFlow機器學習項目實戰》的4.4節,作者使用了skflow。skflow剛出來的時候火了一陣,但是接口變化非常頻繁,是以後來用的人也越來越少,也導緻4.4的程式不能運作了。

但是最近釋出的TensorFlow 1.4中,我們發現該子產品已經內建到了核心子產品,意味着接口基本穩定下來,并有推廣使用的趨勢。是以我把4.4的程式重新用Estimator寫了一下,變量名基本保持不變,代碼如下:

# -*- coding: utf-8 -*-    
 import tensorflow as tf from sklearn import datasets, metrics,
  preprocessing import numpy as np import pandas as pd import os df = pd.
  read_csv("data/CHD.csv", header=0) print( df.describe()) X=df['age'].
  astype(float) feature_columns = [tf.contrib.layers.real_valued_column
  ("X", dimension=1)] classifier = tf.estimator.LinearClassifier
  (feature_columns=feature_columns,                                   
  model_dir=os.path.join(".","tmp","logistic")) #classifier =
   tf.estimator.LinearClassifier(feature_columns=feature_columns) 
input_fn_train= tf.estimator.inputs.numpy_input_fn(              
 x={"X" : np.array(X)},           
  y=np.array(df['chd']),            
 batch_size=2,                
 num_epochs=None,              
shuffle=True) classifier.train(input_fn=input_fn_train,steps=2000)
 #模型的準确度 score =  classifier.evaluate(input_fn=input_fn_train,steps=50)
 ["accuracy"] print("Accuracy: %f" % score)      

注:這段程式可以在Ubuntu和MacOS下面跑,但是Windows下面還不行,是路徑的問題。這應該是Estimator的一個BUG,在contrib.learn下也是一樣的不行。如果想在windows下,一定要用注釋掉的部分。

這裡面最難寫的是input_fn函數,也是最重要的函數,我在這段程式中直接使用了numpy_input_fn來建構。[1]中除了這個方法還給出了從pandas建構的方法,大家可以自己嘗試。

input_fn帶來了一個好處,就是可以按照生産者消費者模式讀取資料,具體的解釋可以參考[2]。簡單的解釋,就是IO一般都比較慢,我們需要在資料處理的過程中進行讀取資料,那樣就可以充分的節省時間,這樣就設計多線程在背景不斷的取資料。

feature_colums的建構需要一定的技巧,這個主要參考[3]

另外的一個變化就是模型的準确度不再是用metric子產品,而是Estimator自帶的子產品。

如果大家有什麼問題歡迎留言

Reference

[1] ​​為Estimator建構輸入函數​​