参考资料:
https://github.com/soumith/cvpr2015/blob/master/Deep%20Learning%20with%20Torch.ipynb
将整个框架分为以下几个模块:
1、options设置
2、train、test预处理以及读取
3、net结构以及criterion建立
4、train设置
5、test设置
6、保存中间结果以及从断点恢复训练(待完善)
以上每个功能模块分别由一个单独的文件实现,整个工程共包含以下7个文件:
main.lua
opt.lua
dataloder.lua
model.lua
train.lua
test.lua
checkpoint.lua(待完善)

main.lua 的代码如下:
require 'torch'
require 'nn'
require 'optim'
local DataLoder = require 'dataloder' --load the dataloder.lua
local opts = require 'opt'
local Model = require 'model'
local Test = require 'test'
local checkpoints = require 'checkpoint'
local Trainer = require 'train'
-- main.lua entry point: parse options, seed the RNGs, load the data and the
-- model/criterion, optionally move everything to the GPU, train for
-- opt.max_epoch epochs, then print the value returned by Test.run.

-- This must be a call. Assigning to torch.setdefaulttensortype would replace
-- the library function with a string and leave the default type unchanged.
torch.setdefaulttensortype('torch.FloatTensor')

-- Options must be parsed before anything reads `opt` (seeding, threads, ...).
local opt = opts.parse(arg) --load the options

-- torch.setnumthreads requires a number argument. Only override the default
-- thread count when the options actually provide one.
-- NOTE(review): assumes opt.lua may define `nThreads`; confirm the option name.
if opt.nThreads then
   torch.setnumthreads(opt.nThreads)
end

local use_cuda = (opt.type == 'cuda')

-- Seed before loading data / building the model so results are reproducible.
torch.manualSeed(opt.manualSeed)
if use_cuda then
   -- cutorch provides Tensor:cuda() and cutorch.manualSeedAll. cunn provides
   -- :cuda() for nn modules and criterions. Loading them only here keeps
   -- CPU-only runs working on machines without CUDA.
   require 'cutorch'
   require 'cunn'
   cutorch.manualSeedAll(opt.manualSeed)
end

local trainset, testset = DataLoder.creat(opt) --load the dataset
local model, criterion = Model.setup(opt) --load the model,criterion

if use_cuda then --turn on gpu:model-criterion-data-label
   model = model:cuda()
   criterion = criterion:cuda()
   trainset.data = trainset.data:cuda()
   trainset.label = trainset.label:cuda()
   testset.data = testset.data:cuda()
   testset.label = testset.label:cuda()
end

--- Number of training samples.
-- Uses the size of the first dimension of the data tensor. Calling
-- self.data:size() without an argument would return the sizes of ALL
-- dimensions (a torch.LongStorage), not a sample count.
-- @treturn number
function trainset:size() --prepare for training
   return self.data:size(1)
end

local trainer = Trainer(model, criterion, opt)
-- Kept local (was an accidental global). It is only consumed by the
-- currently disabled checkpoint call below.
local bestModel = false
for epoch = 1, opt.max_epoch do
   local current_error = trainer:train(epoch, trainset)
   print(('epoch %d: training error = %s'):format(epoch, tostring(current_error)))
   --save the current station
   --checkpoints.save(epoch, model, trainer.optimState, bestModel, opt)
end

local correct_rate = Test.run(opt, testset, model)
print(correct_rate)
运行程序时,在文件所在目录的终端中执行 th main.lua 即可。若需修改options,例如使用GPU运行:th main.lua -type cuda(选项的写法需与 opt.lua 中 cmd:option 的定义一致)