使用方式
下載下傳工程
git clone https://github.com/SpikeKing/DL-Project-Template
建立和激活虛拟環境
virtualenv venv
source venv/bin/activate
安裝Python依賴庫
pip install -r requirements.txt
開發流程
● 定義自己的資料加載類,繼承DataLoaderBase;
● 定義自己的網絡結構類,繼承ModelBase;
● 定義自己的模型訓練類,繼承TrainerBase;
● 定義自己的樣本預測類,繼承InferBase;
● 定義自己的配置檔案,寫入實驗的相關參數;
執行訓練模型和預測樣本操作。
示例工程
識别MNIST庫中手寫數字,工程
simple_mnist
訓練:
python main_train.py -c configs/simple_mnist_config.json
預測:
python main_test.py -c configs/simple_mnist_config.json -m simple_m
nist.weights.10-0.24.hdf5

TensorBoard
工程架構
主要元件
DataLoader
操作步驟:
● 建立自己的加載資料類,繼承DataLoaderBase基類;
● 覆寫
get_train_data()
和
get_test_data()
,傳回訓練和測試資料;
Model
● 建立自己的網絡結構類,繼承ModelBase基類;
build_model()
,建立網絡結構;
● 在構造器中,調用
build_model()
;
注意:
plot_model()
支援繪制網絡結構;
Trainer
● 建立自己的訓練類,繼承TrainerBase基類;
● 參數:網絡結構model、訓練資料data;
train()
,fit資料,訓練網絡結構;
注意:支援在訓練中調用callbacks,額外添加模型存儲、TensorBoard、FPR度量等。
Infer
● 建立自己的預測類,繼承InferBase基類;
load_model()
,提供模型加載功能;
predict()
,提供樣本預測功能;
Config
定義在模型訓練過程中所需的參數,JSON格式,支援:學習率、Epoch、Batch等參數。
Main
● 建立配置檔案config;
● 建立資料加載類dataloader;
● 建立網絡結構類model;
● 建立訓練類trainer,參數是訓練和測試資料、模型;
● 執行訓練類trainer的train();
● 處理預測樣本test;
● 建立預測類infer;
● 執行預測類infer的predict();
原文釋出時間為:2018-10-24
本文來自雲栖社群合作夥伴“大資料挖掘DT機器學習”,了解相關資訊可以關注“
大資料挖掘DT機器學習”。