天天看点

深度学习工程模板

使用方式

下载工程

git clone https://github.com/SpikeKing/DL-Project-Template           

创建和激活虚拟环境

virtualenv venv

source venv/bin/activate

安装Python依赖库

pip install -r requirements.txt           

开发流程

 ●  定义自己的数据加载类,继承DataLoaderBase;

 ●  定义自己的网络结构类,继承ModelBase;

 ●  定义自己的模型训练类,继承TrainerBase;

 ●  定义自己的样本预测类,继承InferBase;

 ●  定义自己的配置文件,写入实验的相关参数;

执行训练模型和预测样本操作。

示例工程

识别MNIST库中手写数字,工程

simple_mnist

训练:

python main_train.py -c configs/simple_mnist_config.json           

预测:

python main_test.py -c configs/simple_mnist_config.json -m simple_m

nist.weights.10-0.24.hdf5

深度学习工程模板

TensorBoard

深度学习工程模板

工程架构

深度学习工程模板

主要组件

DataLoader

操作步骤:

 ●  创建自己的加载数据类,继承DataLoaderBase基类;

 ●  覆写

get_train_data()

get_test_data()

,返回训练和测试数据;

Model

 ●  创建自己的网络结构类,继承ModelBase基类;

build_model()

,创建网络结构;

 ●  在构造器中,调用

build_model()

注意:

plot_model()

支持绘制网络结构;

Trainer

 ●  创建自己的训练类,继承TrainerBase基类;

 ●  参数:网络结构model、训练数据data;

train()

,fit数据,训练网络结构;

注意:支持在训练中调用callbacks,额外添加模型存储、TensorBoard、FPR度量等。

Infer

 ●  创建自己的预测类,继承InferBase基类;

load_model()

,提供模型加载功能;

predict()

,提供样本预测功能;

Config

定义在模型训练过程中所需的参数,JSON格式,支持:学习率、Epoch、Batch等参数。

Main

 ●  创建配置文件config;

 ●  创建数据加载类dataloader;

 ●  创建网络结构类model;

 ●  创建训练类trainer,参数是训练和测试数据、模型;

 ●  执行训练类trainer的train();

 ●  处理预测样本test;

 ●  创建预测类infer;

 ●  执行预测类infer的predict();

原文发布时间为:2018-10-24

本文来自云栖社区合作伙伴“大数据挖掘DT机器学习”,了解相关信息可以关注“

大数据挖掘DT机器学习

”。