天天看點

Torch7學習筆記[2] ---神經網絡的主體架構

參考資料:

https://github.com/soumith/cvpr2015/blob/master/Deep%20Learning%20with%20Torch.ipynb

将整個架構分為以下幾個子產品:

1、options設定

2、train、test預處理以及讀取

3、net結構以及criterion建立

4、train設定

5、test設定

6、儲存中間結果以及斷點開始(待完善)

y以上每個功能子產品單獨由一個檔案完成,整個結構分為7個檔案

main.lua opt.lua dataloder.lua model.lua train.lua test.lua checkpont.lua(待完善)

require 'torch'
require 'nn'
require 'optim'

local DataLoder = require 'dataloder'  --load the dataloder.lua
local opts = require 'opt'
local Model = require 'model'
local Test = require 'test'
local checkpoints = require 'checkpoint'
local Trainer = require 'train'

torch.setdefaulttensortype = ('torch.FloatTensor')  --
torch.setnumthreads()
torch.manualSeed(opt.manualSeed)
cutorch.manualSeedAll(opt.manualSeed)

local opt = opts.parse(arg)  --load the options
local trainset,testset = DataLoder.creat(opt) --load the dataset

local model,criterion = Model.setup(opt) --load the model,criterion
if(opt.type == 'cuda')   then  --turn on gpu:model-criterion-data-label
    model = model:cuda() 
    criterion = criterion:cuda()
    trainset.data = trainset.data:cuda()
    trainset.label = trainset.label:cuda()
    testset.data = testset.data:cuda()
    testset.label = testset.label:cuda()
end

function trainset:size() --prepare for training 
    return self.data:size() 
end
local trainer = Trainer(model,criterion,opt)

bestModel = false
for epoch = ,opt.max_epoch do
    local current_error = trainer:train(epoch,trainset)
    --save the current station
    --checkpoints.save(epoch, model, trainer.optimState, bestModel, opt)
end

local correct_rate = Test.run(opt,testset,model)
print(correct_rate)
           

運作程式時,直接在檔案所在目錄終端執行:th main.lua 即可運作程式。若需改變options,例如gpu運作:th main.lua –type cuda

繼續閱讀