天天看點

簡明 TensorFlow 教程 — 第二部分:混合學習

<b>本文講的是簡明 TensorFlow 教程 — 第二部分:混合學習,</b>

在本文中,我們将示範一個寬 N 深度網絡,它使用廣泛的線性模型與前饋網絡同時訓練,以證明它比一些傳統的機器學習技術能提供精度更高的預測結果。下面我們将使用混合學習方法預測泰坦尼克号乘客的生存機率。

混合學習技術已被 Google 應用在 Play 商店中提供應用推薦。Youtube 也在使用類似的混合學習技術來推薦視訊。

寬和深網絡将線性模型與前饋神經網絡結合,使得我們的預測将具有記憶和通用化。 這種類型的模型可以用于分類和回歸問題。 這種方法能夠在減少特征工程的同時擁有相對精确的預測結果,可謂一箭雙雕。

簡明 TensorFlow 教程 — 第二部分:混合學習

首先,我們要将所有列定義為連續或分類。

連續的列 - 連續範圍内的任何數值。 像錢或年齡。

分類列 - 有限集的一部分。 像男性或女性,或着乘客的國籍。

因為我們隻是想看看一個人是否幸存下來,這是一個二進制分類問題。 是以預測結果 1 表示該乘客幸存下來,而結果 0 表示沒有幸存。(也即建立一列來儲存預測結果)

現在我們可以建立列和添加嵌入層。 當我們建構我們的模型時,我們想要将我們的分類列變成稀疏列。 對于沒有那麼多類别(例如 Sex 或 Embarked(S,Q 或 C))的列,我們根據類名将它們轉換為稀疏列。

對于類别較多的分類列,由于我們沒有一個詞彙表檔案将所有可能的類别映射為一個整數,是以我們使用哈希值作為鍵值。

我們的連續列使用的是真實的值。 因為 passengerId 是連續的而不是分類的,并且他們已經是整數的 ID 而不是字元串。

我們需要根據年齡對乘客進行分類。 桶化(Bucketization )允許我們找到乘客對應年齡組的生存相關性,而不是将所有年齡作為一個大整體,進而提高我們的準确性。

最後,我們将定義我們的廣度列和深度列。 我們的寬列将有效地記住我們與特征之間的互動。 我們的寬列不會将我們的特征通用化,這是深度列的用處。

擁有這些深度列的好處是,它會将我們提供的高次元稀疏的特征進行降維來計算。

我們通過使用深度列和廣度列來建立分類器,以完成我們的函數。

我們在運作網絡之前要做的最後一件事是為我們的連續和分類列建立映射。 我們先建立一個輸入函數給我們的資料框,它能将我們的資料框轉換為 Tensorflow 可以操作的對象。 這樣做的好處是,我們可以改變和調整我們的 tensors 建立過程。 例如說我們可以将特征列傳遞到 .fit .feature .predict 作為一個單獨建立的列,就像我們上面所描述的一樣,但這個是一個更加簡潔的方案。

現在,做完了以上工作,我們就可以開始編寫訓練功能了

我們讀取預處理後的 csv 檔案,像處理缺失值等。為了讓文章保持簡潔,更多有關預處理的代碼和内容可以在代碼倉庫中找到。

這些 csv 檔案将通過調用 input_fn 函數轉換為 tensors 。 我們先建構評價名額,然後列印我們的預測和評估結果。

簡明 TensorFlow 教程 — 第二部分:混合學習

網絡結果

運作我們的代碼為我們提供了相當好的結果,不需要添加任何額外的列或做任何特征工程。 而且隻要很少的微調這個模型可以得到相對較好的結果。

簡明 TensorFlow 教程 — 第二部分:混合學習

與傳統廣度線性模型一起添加嵌入層的能力,允許通過将稀疏次元降低到低次元來進行準确的預測。

<b></b>

<b>原文釋出時間為:2016年12月20日</b>

<b>本文來自雲栖社群合作夥伴掘金,了解相關資訊可以關注掘金網站。</b>

繼續閱讀