天天看點

深度學習工程模闆

使用方式

下載下傳工程

git clone https://github.com/SpikeKing/DL-Project-Template           

建立和激活虛拟環境

virtualenv venv

source venv/bin/activate

安裝Python依賴庫

pip install -r requirements.txt           

開發流程

 ●  定義自己的資料加載類,繼承DataLoaderBase;

 ●  定義自己的網絡結構類,繼承ModelBase;

 ●  定義自己的模型訓練類,繼承TrainerBase;

 ●  定義自己的樣本預測類,繼承InferBase;

 ●  定義自己的配置檔案,寫入實驗的相關參數;

執行訓練模型和預測樣本操作。

示例工程

識别MNIST庫中手寫數字,工程

simple_mnist

訓練:

python main_train.py -c configs/simple_mnist_config.json           

預測:

python main_test.py -c configs/simple_mnist_config.json -m simple_m

nist.weights.10-0.24.hdf5

深度學習工程模闆

TensorBoard

深度學習工程模闆

工程架構

深度學習工程模闆

主要元件

DataLoader

操作步驟:

 ●  建立自己的加載資料類,繼承DataLoaderBase基類;

 ●  覆寫

get_train_data()

get_test_data()

,傳回訓練和測試資料;

Model

 ●  建立自己的網絡結構類,繼承ModelBase基類;

build_model()

,建立網絡結構;

 ●  在構造器中,調用

build_model()

注意:

plot_model()

支援繪制網絡結構;

Trainer

 ●  建立自己的訓練類,繼承TrainerBase基類;

 ●  參數:網絡結構model、訓練資料data;

train()

,fit資料,訓練網絡結構;

注意:支援在訓練中調用callbacks,額外添加模型存儲、TensorBoard、FPR度量等。

Infer

 ●  建立自己的預測類,繼承InferBase基類;

load_model()

,提供模型加載功能;

predict()

,提供樣本預測功能;

Config

定義在模型訓練過程中所需的參數,JSON格式,支援:學習率、Epoch、Batch等參數。

Main

 ●  建立配置檔案config;

 ●  建立資料加載類dataloader;

 ●  建立網絡結構類model;

 ●  建立訓練類trainer,參數是訓練和測試資料、模型;

 ●  執行訓練類trainer的train();

 ●  處理預測樣本test;

 ●  建立預測類infer;

 ●  執行預測類infer的predict();

原文釋出時間為:2018-10-24

本文來自雲栖社群合作夥伴“大資料挖掘DT機器學習”,了解相關資訊可以關注“

大資料挖掘DT機器學習

”。