通過前面幾章的學習,我們已經掌握了PyTorch中大部分的基礎知識,本章将結合之前講的内容,帶領讀者從頭實作一個完整的深度學習項目。本章的重點不在于如何使用PyTorch的接口,而在于合理地設計程式的結構,使得程式更具可讀性、更易用。
6.1 程式設計實戰:貓和狗二分類
在學習某個深度學習架構時,掌握其基本知識和接口固然重要,但如何合理組織代碼,使得代碼具有良好的可讀性和可擴充性也必不可少。本文不會深入講解過多知識性的東西,更多的則是傳授一些經驗,這些内容可能有些争議,因其受我個人喜好和coding風格影響較大,讀者可以将這部分當成是一種參考或提議,而不是作為必須遵循的準則。歸根到底,都是希望你能以一種更為合理的方式組織自己的程式。
在做深度學習實驗或項目時,為了得到最優的模型結果,中間往往需要很多次的嘗試和修改。而合理的檔案組織結構,以及一些小技巧可以極大地提高代碼的易讀易用性。根據筆者的個人經驗,在從事大多數深度學習研究時,程式都需要實作以下幾個功能:
- 模型定義
- 資料處理和加載
- 訓練模型(Train&Validate)
- 訓練過程的可視化
- 測試(Test/Inference)
另外程式還應該滿足以下幾個要求:
- 模型需具有高度可配置性,便于修改參數、修改模型,反複實驗。
- 代碼應具有良好的組織結構,使人一目了然。
- 代碼應具有良好的說明,使其他人能夠了解。
在之前的章節中,我們已經講解了PyTorch中的絕大部分内容。本章我們将應用這些内容,并結合實際的例子,來講解如何用PyTorch完成Kaggle上的經典比賽:Dogs vs. Cats。本文所有示例程式均在github上開源 。
6.1.1 比賽介紹
Dogs vs. Cats是一個傳統的二分類問題,其訓練集包含25000張圖檔,均放置在同一檔案夾下,命名格式為
<category>.<num>.jpg
, 如
cat.10000.jpg
、
dog.100.jpg
,測試集包含12500張圖檔,命名為
<num>.jpg
,如
1000.jpg
。參賽者需根據訓練集的圖檔訓練模型,并在測試集上進行預測,輸出它是狗的機率。最後送出的csv檔案如下,第一列是圖檔的
<num>
,第二列是圖檔為狗的機率。
id,label
10001,0.889
10002,0.01
…
![](https://img.laitimes.com/img/__Qf2AjLwojIjJCLyojI0JCLicmbwxCdh1mcvZ2LcV2Zh1Wa9M3clN2byBXLzN3btg3P3pVdC5GTzU1ROFzZ65EeZ1mTwEkaOxmSqlVMwkXT5VEROpXW65Ee4k3YsR2VZRHbyg1aGJjYzJEWkZHOXFWdVhUY6VzVZBHctxkeWJjWoFzVhRXUXlld4d0YxkTeMZTTINGMShUYvwlbj5yZtlmbkN3YuQnclZnbvN2Ztl2Lc9CX6MHc0RHaiojIsJye.jpg)
6.1.2 檔案組織架構
前面提到過,程式主要包含以下功能:
- 模型定義
- 資料加載
- 訓練和測試
首先來看程式檔案的組織結構:
├── checkpoints/
├── data/
│ ├── __init__.py
│ ├── dataset.py
│ └── get_data.sh
├── models/
│ ├── __init__.py
│ ├── AlexNet.py
│ ├── BasicModule.py
│ └── ResNet34.py
└── utils/
│ ├── __init__.py
│ └── visualize.py
├── config.py
├── main.py
├── requirements.txt
├── README.md
其中:
-
: 用于儲存訓練好的模型,可使程式在異常退出後仍能重新載入模型,恢複訓練。checkpoints/
-
:資料相關操作,包括資料預處理、dataset實作等。data/
-
:模型定義,可以有多個模型,例如上面的AlexNet和ResNet34,一個模型對應一個檔案。models/
-
:可能用到的工具函數,在本次實驗中主要是封裝了可視化工具。utils/
-
:配置檔案,所有可配置的變量都集中在此,并提供預設值。config.py
-
:主檔案,訓練和測試程式的入口,可通過不同的指令來指定不同的操作和參數。main.py
-
:程式依賴的第三方庫。requirements.txt
-
:提供程式的必要說明。README.md
6.1.3 關于__init__.py
可以看到,幾乎每個檔案夾下都有
__init__.py
,一個目錄如果包含了
__init__.py
檔案,那麼它就變成了一個包(package)。
__init__.py
可以為空,也可以定義包的屬性和方法,但其必須存在,其它程式才能從這個目錄中導入相應的子產品或函數。例如在
data/
檔案夾下有
__init__.py
,則在
main.py
中就可以
from data.dataset import DogCat
。而如果在
__init__.py
中寫入
from .dataset import DogCat
,則在main.py中就可以直接寫為:
from data import DogCat
,或者
import data; dataset = data.DogCat
,相比于
from data.dataset import DogCat
更加便捷。
6.1.4 資料加載
資料的相關處理主要儲存在
data/dataset.py
中。關于資料加載的相關操作,在上一章中我們已經提到過,其基本原理就是使用
Dataset
提供資料集的封裝,再使用
Dataloader
實作資料并行加載。Kaggle提供的資料包括訓練集和測試集,而我們在實際使用中,還需專門從訓練集中取出一部分作為驗證集。對于這三類資料集,其相應操作也不太一樣,而如果專門寫三個
Dataset
,則稍顯複雜和備援,是以這裡通過加一些判斷來區分。對于訓練集,我們希望做一些資料增強處理,如随機裁剪、随機翻轉、加噪聲等,而驗證集和測試集則不需要。下面看
dataset.py
的代碼:
# coding:utf8
import os
from PIL import Image
from torch.utils import data
from torchvision import transforms as T
class DogCat(data.Dataset):
def __init__(self, root, transforms=None, train=True, test=False):
"""
主要目标: 擷取所有圖檔的位址,并根據訓練,驗證,測試劃分資料
"""
self.test = test
imgs = [os.path.join(root, img) for img in os.listdir(root)]
# test1: data/test1/8973.jpg
# train: data/train/cat.10004.jpg
if self.test:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2].split('/')[-1]))
else:
imgs = sorted(imgs, key=lambda x: int(x.split('.')[-2]))
imgs_num = len(imgs)
if self.test:
self.imgs = imgs
elif train:
self.imgs = imgs[:int(0.7 * imgs_num)]
else:
self.imgs = imgs[int(0.7 * imgs_num):]
if transforms is None:
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
if self.test or not train:
self.transforms = T.Compose([
T.Resize(224),
T.CenterCrop(224),
T.ToTensor(),
normalize
])
else:
self.transforms = T.Compose([
T.Resize(256),
T.RandomCrop(224),
T.RandomHorizontalFlip(),
T.ToTensor(),
normalize
])
def __getitem__(self, index):
"""
一次傳回一張圖檔的資料
"""
img_path = self.imgs[index]
if self.test:
label = int(self.imgs[index].split('.')[-2].split('/')[-1])
else:
label = 1 if 'dog' in img_path.split('/')[-1] else 0
data = Image.open(img_path)
data = self.transforms(data)
return data, label
def __len__(self):
return len(self.imgs)
關于資料集使用的注意事項,在上一章中已經提到,将檔案讀取等費時操作放在
__getitem__
函數中,利用多程序加速。避免一次性将所有圖檔都讀進記憶體,不僅費時也會占用較大記憶體,而且不易進行資料增強等操作。另外在這裡,我們将訓練集中的30%作為驗證集,可用來檢查模型的訓練效果,避免過拟合。在使用時,我們可通過dataloader加載資料。
train_dataset = DogCat(opt.train_data_root, train=True)
trainloader = DataLoader(train_dataset,
batch_size = opt.batch_size,
shuffle = True,
num_workers = opt.num_workers)
for ii, (data, label) in enumerate(trainloader):
train()
6.1.5 模型定義
模型的定義主要儲存在
models/
目錄下,其中
BasicModule
是對
nn.Module
的簡易封裝,提供快速加載和儲存模型的接口。
# coding:utf8
import time
import torch as t
class BasicModule(t.nn.Module):
"""
封裝了nn.Module,主要是提供了save和load兩個方法
"""
def __init__(self):
super(BasicModule, self).__init__()
self.model_name = str(type(self)) # 預設名字
def load(self, path):
"""
可加載指定路徑的模型
"""
self.load_state_dict(t.load(path))
def save(self, name=None):
"""
儲存模型,預設使用“模型名字+時間”作為檔案名
"""
if name is None:
prefix = 'checkpoints/' + self.model_name + '_'
name = time.strftime(prefix + '%Y%m%d%H%M%S.pth')
t.save(self.state_dict(), name)
return name
def get_optimizer(self, lr, weight_decay):
return t.optim.Adam(self.parameters(), lr=lr, weight_decay=weight_decay)
class Flat(t.nn.Module):
"""
把輸入reshape成(batch_size,dim_length)
"""
def __init__(self):
super(Flat, self).__init__()
# self.size = size
def forward(self, x):
return x.view(x.size(0), -1)
在實際使用中,直接調用
model.save()
及
model.load(opt.load_path)
即可。
其它自定義模型一般繼承
BasicModule
,然後實作自己的模型。其中
AlexNet.py
實作了AlexNet,
ResNet34
實作了ResNet34。在
models/__init__py
中,代碼如下:
from .AlexNet import AlexNet
from .ResNet34 import ResNet34
這樣在主函數中就可以寫成:
from models import AlexNet
或
import models
model = models.AlexNet()
或
import models
model = getattr('models', 'AlexNet')()
其中最後一種寫法最為關鍵,這意味着我們可以通過字元串直接指定使用的模型,而不必使用判斷語句,也不必在每次新增加模型後都修改代碼。新增模型後隻需要在
models/__init__.py
中加上
from .new_module import new_module
即可。
其它關于模型定義的注意事項,在上一章中已詳細講解,這裡就不再贅述,總結起來就是:
- 盡量使用
(比如AlexNet)。nn.Sequential
- 将經常使用的結構封裝成子Module(比如GoogLeNet的Inception結構,ResNet的Residual Block結構)。
- 将重複且有規律性的結構,用函數生成(比如VGG的多種變體,ResNet多種變體都是由多個重複卷積層組成)。
6.1.6 工具函數
在項目中,我們可能會用到一些helper方法,這些方法可以統一放在
utils/
檔案夾下,需要使用時再引入。在本例中主要是封裝了可視化工具visdom的一些操作,其代碼如下,在本次實驗中隻會用到
plot
方法,用來統計損失資訊。
# coding:utf8
import time
import numpy as np
import visdom
class Visualizer(object):
"""
封裝了visdom的基本操作,但是你仍然可以通過`self.vis.function`
調用原生的visdom接口
"""
def __init__(self, env='default', **kwargs):
self.vis = visdom.Visdom(env=env, use_incoming_socket=False, **kwargs)
# 畫的第幾個數,相當于橫座标
# 儲存(’loss',23) 即loss的第23個點
self.index = {}
self.log_text = ''
def reinit(self, env='default', **kwargs):
"""
修改visdom的配置
"""
self.vis = visdom.Visdom(env=env, **kwargs)
return self
def plot_many(self, d):
"""
一次plot多個
@params d: dict (name,value) i.e. ('loss',0.11)
"""
for k, v in d.items():
self.plot(k, v)
def img_many(self, d):
for k, v in d.items():
self.img(k, v)
def plot(self, name, y, **kwargs):
"""
self.plot('loss',1.00)
"""
x = self.index.get(name, 0)
self.vis.line(Y=np.array([y]), X=np.array([x]),
win=name,
opts=dict(title=name),
update=None if x == 0 else 'append',
**kwargs
)
self.index[name] = x + 1
def img(self, name, img_, **kwargs):
"""
self.img('input_img',t.Tensor(64,64))
self.img('input_imgs',t.Tensor(3,64,64))
self.img('input_imgs',t.Tensor(100,1,64,64))
self.img('input_imgs',t.Tensor(100,3,64,64),nrows=10)
!!!don‘t ~~self.img('input_imgs',t.Tensor(100,64,64),nrows=10)~~!!!
"""
self.vis.images(img_.cpu().numpy(),
win=name,
opts=dict(title=name),
**kwargs
)
def log(self, info, win='log_text'):
"""
self.log({'loss':1,'lr':0.0001})
"""
self.log_text += ('[{time}] {info} <br>'.format(
time=time.strftime('%Y%m%d %H:%M:%S'),
info=info))
self.vis.text(self.log_text, win)
def __getattr__(self, name):
return getattr(self.vis, name)
6.1.7 配置檔案
在模型定義、資料處理和訓練等過程都有很多變量,這些變量應提供預設值,并統一放置在配置檔案中,這樣在後期調試、修改代碼或遷移程式時會比較友善,在這裡我們将所有可配置項放在
config.py
中。
# coding:utf8
import warnings
import torch as t
class DefaultConfig(object):
env = 'default' # visdom 環境
vis_port = 8097 # visdom 端口
model = 'SqueezeNet' # 使用的模型,名字必須與models/__init__.py中的名字一緻
train_data_root = './data/train/' # 訓練集存放路徑
test_data_root = './data/test/' # 測試集存放路徑
load_model_path = None # 加載預訓練的模型的路徑,為None代表不加載
batch_size = 32 # batch size
use_gpu = True # user GPU or not
num_workers = 0 # how many workers for loading data
print_freq = 20 # print info every N batch
debug_file = './debug/debug.txt' # if os.path.exists(debug_file): enter ipdb
result_file = 'result.csv'
max_epoch = 10
lr = 0.001 # initial learning rate
lr_decay = 0.5 # when val_loss increase, lr = lr*lr_decay
weight_decay = 0e-5 # 損失函數
def _parse(self, kwargs):
"""
根據字典kwargs 更新 config參數
"""
for k, v in kwargs.items():
if not hasattr(self, k):
warnings.warn("Warning: opt has not attribut %s" % k)
setattr(self, k, v)
opt.device = t.device('cuda') if opt.use_gpu else t.device('cpu')
print('user config:')
for k, v in self.__class__.__dict__.items():
if not k.startswith('_'):
print(k, getattr(self, k))
opt = DefaultConfig()
可配置的參數主要包括:
- 資料集參數(檔案路徑、batch_size等)
- 訓練參數(學習率、訓練epoch等)
- 模型參數
這樣我們在程式中就可以這樣使用:
import models
from config import DefaultConfig
opt = DefaultConfig()
lr = opt.lr
model = getattr(models, opt.model)
dataset = DogCat(opt.train_data_root)
這些都隻是預設參數,在這裡還提供了更新函數,根據字典更新配置參數。
def _parse(self, kwargs):
"""
根據字典kwargs 更新 config參數
"""
for k, v in kwargs.items():
if not hasattr(self, k):
warnings.warn("Warning: opt has not attribut %s" % k)
setattr(self, k, v)
opt.device = t.device('cuda') if opt.use_gpu else t.device('cpu')
print('user config:')
for k, v in self.__class__.__dict__.items():
if not k.startswith('_'):
print(k, getattr(self, k))
這樣我們在實際使用時,并不需要每次都修改
config.py
,隻需要通過指令行傳入所需參數,覆寫預設配置即可。
例如:
opt = DefaultConfig()
new_config = {'lr':0.1,'use_gpu':False}
opt.parse(new_config)
opt.lr == 0.1
6.1.8 main.py
在講解主程式
main.py
之前,我們先來看看2017年3月谷歌開源的一個指令行工具fire,通過
pip install fire
即可安裝。下面來看看
fire
的基礎用法,假設
example.py
檔案内容如下:
import fire
def add(x, y):
return x + y
def mul(**kwargs):
a = kwargs['a']
b = kwargs['b']
return a * b
if __name__ == '__main__':
fire.Fire()
那麼我們可以使用:
python example.py add 1 2 # 執行add(1, 2)
python example.py mul --a=1 --b=2 # 執行mul(a=1, b=2), kwargs={'a':1, 'b':2}
python example.py add --x=1 --y==2 # 執行add(x=1, y=2)
可見,隻要在程式中運作
fire.Fire()
,即可使用指令行參數
python file <function> [args,] {--kwargs,}
。fire還支援更多的進階功能,具體請參考官方指南《The Python Fire Guide》。
在主程式
main.py
中,主要包含四個函數,其中三個需要指令行執行,
main.py
的代碼組織結構如下:
def train(**kwargs):
"""
訓練
"""
pass
def val(model, dataloader):
"""
計算模型在驗證集上的準确率等資訊,用以輔助訓練
"""
pass
def test(**kwargs):
"""
測試(inference)
"""
pass
def help():
"""
列印幫助的資訊
"""
print('help')
if __name__=='__main__':
import fire
fire.Fire()
根據fire的使用方法,可通過
python main.py <function> --args=xx
的方式來執行訓練或者測試。
訓練
訓練的主要步驟如下:
- 定義網絡
- 定義資料
- 定義損失函數和優化器
- 計算重要名額
- 開始訓練
- 訓練網絡
- 可視化各種名額
- 計算在驗證集上的名額
訓練函數的代碼如下:
def train(**kwargs):
opt._parse(kwargs)
vis = Visualizer(opt.env, port=opt.vis_port)
# step1: configure model
model = getattr(models, opt.model)()
if opt.load_model_path:
model.load(opt.load_model_path)
model.to(opt.device)
# step2: data
train_data = DogCat(opt.train_data_root, train=True)
val_data = DogCat(opt.train_data_root, train=False)
train_dataloader = DataLoader(train_data, opt.batch_size,
shuffle=True, num_workers=opt.num_workers)
val_dataloader = DataLoader(val_data, opt.batch_size,
shuffle=False, num_workers=opt.num_workers)
# step3: criterion and optimizer
criterion = t.nn.CrossEntropyLoss()
lr = opt.lr
optimizer = model.get_optimizer(lr, opt.weight_decay)
# step4: meters
loss_meter = meter.AverageValueMeter()
confusion_matrix = meter.ConfusionMeter(2)
previous_loss = 1e10
# train
for epoch in range(opt.max_epoch):
loss_meter.reset()
confusion_matrix.reset()
for ii, (data, label) in tqdm(enumerate(train_dataloader)):
# train model
input = data.to(opt.device)
target = label.to(opt.device)
optimizer.zero_grad()
score = model(input)
loss = criterion(score, target)
loss.backward()
optimizer.step()
# meters update and visualize
loss_meter.add(loss.item())
# detach 一下更安全保險
confusion_matrix.add(score.detach(), target.detach())
if (ii + 1) % opt.print_freq == 0:
vis.plot('loss', loss_meter.value()[0])
print("loss:", loss_meter.value()[0])
# 進入debug模式
# if os.path.exists(opt.debug_file):
# import ipdb;
# ipdb.set_trace()
print("儲存檢查點...")
model.save()
cm_value = confusion_matrix.value()
vis.plot('train_accuracy', 100. * (cm_value[0][0] + cm_value[1][1]) / cm_value.sum())
# validate and visualize
val_cm, val_accuracy = val(model, val_dataloader)
vis.plot('val_accuracy', val_accuracy)
vis.log("\tepoch:{epoch},\tlr:{lr},\tloss:{loss},\ttrain_cm:{train_cm},\tval_cm:{val_cm}\t".format(
epoch=epoch, lr=lr, loss=loss_meter.value()[0], train_cm=str(confusion_matrix.value()),
val_cm=str(val_cm.value())))
# update learning rate
if loss_meter.value()[0] > previous_loss:
lr = lr * opt.lr_decay
# 第二種降低學習率的方法:不會有moment等資訊的丢失
for param_group in optimizer.param_groups:
param_group['lr'] = lr
previous_loss = loss_meter.value()[0]
這裡用到了PyTorchNet裡面的一個工具: meter。meter提供了一些輕量級的工具,用于幫助使用者快速統計訓練過程中的一些名額。
AverageValueMeter
能夠計算所有數的平均值和标準差,這裡用來統計一個epoch中損失的平均值。
confusionmeter
用來統計分類問題中的分類情況,是一個比準确率更詳細的統計名額。例如對于表格6-1,共有50張狗的圖檔,其中有35張被正确分類成了狗,還有15張被誤判成貓;共有100張貓的圖檔,其中有91張被正确判為了貓,剩下9張被誤判成狗。相比于準确率等統計資訊,混淆矩陣更能展現分類的結果,尤其是在樣本比例不均衡的情況下。
表6-1 混淆矩陣
樣本 | 判為狗 | 判為貓 |
---|---|---|
實際是狗 | 35 | 15 |
實際是貓 | 9 | 91 |
PyTorchNet從TorchNet遷移而來,提供了很多有用的工具,但其目前開發和文檔都還不是很完善,本書不做過多的講解。
驗證
驗證相對來說比較簡單,但要注意需将模型置于驗證模式(
model.eval()
),驗證完成後還需要将其置回為訓練模式(
model.train()
),這兩句代碼會影響
BatchNorm
和
Dropout
等層的運作模式。驗證模型準确率的代碼如下。
@t.no_grad()
def val(model, dataloader):
"""
計算模型在驗證集上的準确率等資訊
"""
model.eval()
confusion_matrix = meter.ConfusionMeter(2)
for ii, (val_input, label) in tqdm(enumerate(dataloader)):
val_input = val_input.to(opt.device)
score = model(val_input)
confusion_matrix.add(score.detach().squeeze(), label.type(t.LongTensor))
model.train()
cm_value = confusion_matrix.value()
accuracy = 100. * (cm_value[0][0] + cm_value[1][1]) / (cm_value.sum())
return confusion_matrix, accuracy
測試
測試時,需要計算每個樣本屬于狗的機率,并将結果儲存成csv檔案。測試的代碼與驗證比較相似,但需要自己加載模型和資料。
@t.no_grad() # pytorch>=0.5
def test(**kwargs):
opt._parse(kwargs)
# configure model
model = getattr(models, opt.model)().eval()
if opt.load_model_path:
model.load(opt.load_model_path)
model.to(opt.device)
# data
train_data = DogCat(opt.test_data_root, test=True)
test_dataloader = DataLoader(train_data, batch_size=opt.batch_size, shuffle=False, num_workers=opt.num_workers)
results = []
for ii, (data, path) in tqdm(enumerate(test_dataloader)):
input = data.to(opt.device)
score = model(input)
probability = t.nn.functional.softmax(score, dim=1)[:, 0].detach().tolist()
batch_results = [(path_.item(), probability_) for path_, probability_ in zip(path, probability)]
results += batch_results
write_csv(results, opt.result_file)
return results
def write_csv(results, file_name):
import csv
with open(file_name, 'w') as f:
writer = csv.writer(f)
writer.writerow(['id', 'label'])
writer.writerows(results)
幫助函數
為了友善他人使用, 程式中還應當提供一個幫助函數,用于說明函數是如何使用。程式的指令行接口中有衆多參數,如果手動用字元串表示不僅複雜,而且後期修改config檔案時,還需要修改對應的幫助資訊,十分不便。這裡使用了Python标準庫中的inspect方法,可以自動擷取config的源代碼。help的代碼如下:
def help():
"""
列印幫助的資訊: python file.py help
"""
print("""
usage : python file.py <function> [--args=value]
<function> := train | test | help
example:
python {0} train --env='env0701' --lr=0.01
python {0} test --dataset='path/to/dataset/root/'
python {0} help
avaiable args:""".format(__file__))
from inspect import getsource
source = (getsource(opt.__class__))
print(source)
當使用者執行
python main.py help
的時候,會列印如下幫助資訊:
usage : python main.py <function> [--args=value,]
<function> := train | test | help
example:
python main.py train --env='env0701' --lr=0.01
python main.py test --dataset='path/to/dataset/'
python main.py help
avaiable args:
class DefaultConfig(object):
env = 'default' # visdom 環境
model = 'AlexNet' # 使用的模型
train_data_root = './data/train/' # 訓練集存放路徑
test_data_root = './data/test' # 測試集存放路徑
load_model_path = 'checkpoints/model.pth' # 加載預訓練的模型
batch_size = 128 # batch size
use_gpu = True # user GPU or not
num_workers = 4 # how many workers for loading data
print_freq = 20 # print info every N batch
debug_file = './debug/debug.txt'
result_file = 'result.csv' # 結果檔案
max_epoch = 10
lr = 0.1 # initial learning rate
lr_decay = 0.95 # when val_loss increase, lr = lr*lr_decay
weight_decay = 1e-4 # 損失函數
6.1.9 使用
正如
help
函數的列印資訊所述,可以通過指令行參數指定變量名.下面是三個使用例子,fire會将包含
-
的指令行參數自動轉層下劃線
_
,也會将非數值的值轉成字元串。是以
--train-data-root=data/train
和
--train_data_root='data/train'
是等價的。
# 訓練模型
python main.py train
--train-data-root=data/train/
--lr=0.005
--batch-size=32
--model='ResNet34'
--max-epoch = 20
# 測試模型
python main.py test
--test-data-root=data/test
--load-model-path='checkpoints/resnet34_00:23:05.pth'
--batch-size=128
--model='ResNet34'
--num-workers=12
# 列印幫助資訊
python main.py help
實驗過程
本章程式及資料下載下傳:百度網盤,提取碼:aw26。
首先,在指令行cmd紅啟動visdom伺服器:
python -m visdom.server
然後,訓練模型:
python main.py train
訓練結果如下:
從上述結果可以看出,模型的精度可以達到97%以上。你也可以手動更改模型,通過調節參數來進一步提升模型的準确率。
最後,測試模型:
python main.py test
第二清單示預測為狗的機率:
我們來看一下測試集圖檔:
可以看到,模型能夠正确識别出很多狗和貓了,但是還存在很大的改進空間。
6.1.10 争議
以上的程式設計規範帶有作者強烈的個人喜好,并不想作為一個标準,而是作為一個提議和一種參考。上述設計在很多地方還有待商榷,例如對于訓練過程是否應該封裝成一個
trainer
對象,或者直接封裝到
BaiscModule
的
train
方法之中。對指令行參數的處理也有不少值得讨論之處。是以不要将本文中的觀點作為一個必須遵守的規範,而應該看作一個參考。
本章中的設計可能會引起不少争議,其中比較值得商榷的部分主要有以下兩個方面:
- 指令行參數的設定。目前大多數程式都是使用Python标準庫中的
來處理指令行參數,也有些使用比較輕量級的argparse
。這種處理相對來說對指令行的支援更完備,但根據作者的經驗來看,這種做法不夠直覺,并且代碼量相對來說也較多。比如click
,每次增加一個指令行參數,都必須寫如下代碼:argparse
在讀者眼中,這種實作方式遠不如一個專門的
config.py
來的直覺和易用。尤其是對于使用Jupyter notebook或IPython等互動式調試的使用者來說,
argparse
較難使用。
- 模型訓練。有不少人喜歡将模型的訓練過程內建于模型的定義之中,代碼結構如下所示:
class MyModel(nn.Module):
def __init__(self,opt):
self.dataloader = Dataloader(opt)
self.optimizer = optim.Adam(self.parameters(),lr=0.001)
self.lr = opt.lr
self.model = make_model()
def forward(self,input):
pass
def train_(self):
# 訓練模型
for epoch in range(opt.max_epoch)
for ii,data in enumerate(self.dataloader):
train_epoch()
model.save()
def train_epoch(self):
pass
抑或是專門設計一個
Trainer
對象,形如:
"""
code simplified from:
https://github.com/pytorch/pytorch/blob/master/torch/utils/trainer/trainer.py
"""
import heapq
from torch.autograd import Variable
class Trainer(object):
def __init__(self, model=None, criterion=None, optimizer=None, dataset=None):
self.model = model
self.criterion = criterion
self.optimizer = optimizer
self.dataset = dataset
self.iterations = 0
def run(self, epochs=1):
for i in range(1, epochs + 1):
self.train()
def train(self):
for i, data in enumerate(self.dataset, self.iterations + 1):
batch_input, batch_target = data
self.call_plugins('batch', i, batch_input, batch_target)
input_var = Variable(batch_input)
target_var = Variable(batch_target)
plugin_data = [None, None]
def closure():
batch_output = self.model(input_var)
loss = self.criterion(batch_output, target_var)
loss.backward()
if plugin_data[0] is None:
plugin_data[0] = batch_output.data
plugin_data[1] = loss.data
return loss
self.optimizer.zero_grad()
self.optimizer.step(closure)
self.iterations += i
還有一些人喜歡模仿keras和scikit-learn的設計,設計一個
fit
接口。對讀者來說,這些處理方式很難說哪個更好或更差,找到最适合自己的方法才是最好的。
BasicModule
的封裝,可多可少。訓練過程中的很多操作都可以移到
BasicModule
之中,比如
get_optimizer
方法用來擷取優化器,比如
train_step
用來執行單歩訓練。對于不同的模型,如果對應的優化器定義不一樣,或者是訓練方法不一樣,可以複寫這些函數自定義相應的方法,取決于自己的喜好和項目的實際需求。
6.2 PyTorch Debug指南
6.2.1 ipdb介紹
很多初學者用print或log調試程式,這在小規模的程式下很友善。但是更好的調試方法是一邊運作一邊檢查裡面的變量和方法。pdb是一個互動式的調試工具,內建于Python的标準庫之中,由于其強大的功能,被廣泛應用于Python環境中。pdb能讓你根據需求跳轉到任意的Python代碼斷點、檢視任意變量、單步執行代碼,甚至還能修改代碼的值,而不必重新開機程式。ipdb是一個增強版的pdb,可通過
pip install ipdb
安裝。ipdb提供了調試模式下的代碼補全,還具有更好的文法高亮和代碼溯源,以及更好的内省功能,更關鍵的是,它與pdb接口完全相容。
在本書第2章曾粗略地提到過ipdb的基本使用,本章将繼續介紹如何結合PyTorch和ipdb進行調試。首先看一個例子,要是用ipdb,隻需在想要進行調試的地方插入
ipdb.set_trace()
,當代碼運作到此處時,就會自動進入互動式調試模式。
假設有如下程式:
try:
import ipdb
except:
import pdb as ipdb
def sum(x):
r = 0
for ii in x:
r += ii
return r
def mul(x):
r = 1
for ii in x:
r *= ii
return r
ipdf.set_trace()
x = [1,2,3,4,5]
r = sum(x)
r = mul(x)
當程式運作至ipdb.set_trace(),會自動進入debug模式,在該模式中,我們可使用調試指令,如next或縮寫n單步執行,也可檢視Python變量,或是運作Python代碼。如果Python變量名和調式指令沖突,需要在變量名前加"!",這樣ipdb會執行對應的Python代碼,而不是調試指令。下面舉例說明ipdb的調試,這裡重點講解ipdb的兩大功能。
- 檢視:在函數調用堆棧中自由跳轉,并檢視函數的局部變量
- 修改:修改程式中的變量,并能以此影響程式的運作結果。
> e:\debug.py(19)<module>()
18 ipdb.set_trace()
---> 19 x = [1,2,3,4,5]
20 r = sum(x)
ipdb> l 1,21 # list 1,21的縮寫,檢視第1行到第21行的代碼,光标所指的這一行尚未運作
1 try:
2 import ipdb
3 except:
4 import pdb as ipdb
5
6 def sum(x):
7 r = 0
8 for ii in x:
9 r += ii
10 return r
11
12 def mul(x):
13 r = 1
14 for ii in x:
15 r *= ii
16 return r
17
18 ipdb.set_trace()
---> 19 x = [1,2,3,4,5]
20 r = sum(x)
21 r = mul(x)
ipdb> n # next的縮寫,執行下一步
> e:\debug.py(20)<module>()
19 x = [1,2,3,4,5]
---> 20 r = sum(x)
21 r = mul(x)
ipdb> s # step的縮寫,進入sum函數内部
--Call--
> e:\debug.py(6)sum()
5
----> 6 def sum(x):
7 r = 0
ipdb> n # next單步執行
> e:\debug.py(7)sum()
6 def sum(x):
----> 7 r = 0
8 for ii in x:
ipdb> n # next單步執行
> e:\debug.py(8)sum()
7 r = 0
----> 8 for ii in x:
9 r += ii
ipdb> n # next單步執行
> e:\debug.py(9)sum()
8 for ii in x:
----> 9 r += ii
10 return r
ipdb> u # up的縮寫,跳回到上一層的調用
> e:\debug.py(20)<module>()
19 x = [1,2,3,4,5]
---> 20 r = sum(x)
21 r = mul(x)
ipdb> d # down的縮寫,跳到調用的下一層
> e:\debug.py(9)sum()
8 for ii in x:
----> 9 r += ii
10 return r
ipdb> !r # !r 檢視變量r的值,該變量名與調試指令`r(eturn)`沖突
0
ipdb> r # return的縮寫,繼續運作直到函數傳回
--Return--
15
> e:\debug.py(10)sum()
9 r += ii
---> 10 return r
11
ipdb> n # 下一步
> e:\debug.py(21)<module>()
19 x = [1,2,3,4,5]
20 r = sum(x)
---> 21 r = mul(x)
ipdb> x # 檢視變量x的值
[1, 2, 3, 4, 5]
ipdb> x[0] = 10000 # 修改變量x的值
ipdb> b 13 # break的縮寫,AI第13行設定斷點
Breakpoint 1 at e:\debug.py:13
ipdb> c # continue的縮寫,繼續運作,直到遇到斷點
> e:\debug.py(13)mul()
12 def mul(x):
1--> 13 r = 1
14 for ii in x:
ipdb> return # 傳回的是修改後x的乘積
--Return--
1200000
> e:\debug.py(16)mul()
15 r *= ii
---> 16 return r
17
ipdb> q # quit的縮寫,退出debug模式
Exiting Debugger.
關于ipdb的使用還有一些技巧:
- 鍵能夠自動補齊,補齊用法與IPython中的類似。
- j(ump) 能夠跳過中間某些行代碼的執行
- 可以直接在ipdb中修改變量的值
- h(elp)能夠檢視調試指令的用法,比如
可以檢視h(elp)指令的用法,h h
能夠檢視j(ump)指令的用法。h jump
6.2.2 在PyTorch中Debug
PyTorch作為一個動态圖架構,與ipdb結合使用能為調試過程帶來便捷。對TensorFlow等靜态圖架構來說,使用Python接口定義計算圖,然後使用C++代碼執行底層運算,在定義圖的時候不進行任何計算,而在計算的時候又無法使用pdb進行調試,因為pdb調試隻能調試Python代碼,故調試一直是此類靜态圖架構的一個痛點。與TensorFlow不同,PyTorch可以在執行計算的同時定義計算圖,這些計算定義過程是使用Python完成的。雖然底層的計算也是用C/C++完成的,但是我們能夠檢視Python定義部分的變量值,這就已經足夠了。下面我們将舉例說明。
- 如何AIPyTorch中檢視神經網絡各個層的輸出。
- 如何在PyTorch中分析各個參數的梯度。
- 如何動态修改PyTorch的訓練過程。
首先,運作第一節給出的“貓狗大戰”程式:
python main.py train --debug-file='debug/debug.txt'
程式運作一段時間後,在debug目錄下建立debug.txt辨別檔案,當程式檢測到這個檔案存在時,會自動進入debug模式。
99it [00:17, 6.07it/s]loss: 0.22854854568839075
119it [00:21, 5.79it/s]loss: 0.21267264398435753
139it [00:24, 5.99it/s]loss: 0.19839374726372108
> e:\workspace\python\pytorch\chapter6\main.py(80)train()
79 loss_meter.reset()
---> 80 confusion_matrix.reset()
81 for ii, (data, label) in tqdm(enumerate(train_dataloader)):
ipdb> break 88 # 在第88行設定斷點,當程式運作到此處進入debug模式
Breakpoint 1 at e:\workspace\python\pytorch\chapter6\main.py:88
ipdb> # 列印所有參數及其梯度的标準差
for (name,p) in model.named_parameters(): \
print(name,p.data.std(),p.grad.data.std())
model.features.0.weight tensor(0.2615, device='cuda:0') tensor(0.3769, device='cuda:0')
model.features.0.bias tensor(0.4862, device='cuda:0') tensor(0.3368, device='cuda:0')
model.features.3.squeeze.weight tensor(0.2738, device='cuda:0') tensor(0.3023, device='cuda:0')
model.features.3.squeeze.bias tensor(0.5867, device='cuda:0') tensor(0.3753, device='cuda:0')
model.features.3.expand1x1.weight tensor(0.2168, device='cuda:0') tensor(0.2883, device='cuda:0')
model.features.3.expand1x1.bias tensor(0.2256, device='cuda:0') tensor(0.1147, device='cuda:0')
model.features.3.expand3x3.weight tensor(0.0935, device='cuda:0') tensor(0.1605, device='cuda:0')
model.features.3.expand3x3.bias tensor(0.1421, device='cuda:0') tensor(0.0583, device='cuda:0')
model.features.4.squeeze.weight tensor(0.1976, device='cuda:0') tensor(0.2137, device='cuda:0')
model.features.4.squeeze.bias tensor(0.4058, device='cuda:0') tensor(0.1798, device='cuda:0')
model.features.4.expand1x1.weight tensor(0.2144, device='cuda:0') tensor(0.4214, device='cuda:0')
model.features.4.expand1x1.bias tensor(0.4994, device='cuda:0') tensor(0.0958, device='cuda:0')
model.features.4.expand3x3.weight tensor(0.1063, device='cuda:0') tensor(0.2963, device='cuda:0')
model.features.4.expand3x3.bias tensor(0.0489, device='cuda:0') tensor(0.0719, device='cuda:0')
model.features.6.squeeze.weight tensor(0.1736, device='cuda:0') tensor(0.3544, device='cuda:0')
model.features.6.squeeze.bias tensor(0.2420, device='cuda:0') tensor(0.0896, device='cuda:0')
model.features.6.expand1x1.weight tensor(0.1211, device='cuda:0') tensor(0.2428, device='cuda:0')
model.features.6.expand1x1.bias tensor(0.0670, device='cuda:0') tensor(0.0162, device='cuda:0')
model.features.6.expand3x3.weight tensor(0.0593, device='cuda:0') tensor(0.1917, device='cuda:0')
model.features.6.expand3x3.bias tensor(0.0227, device='cuda:0') tensor(0.0160, device='cuda:0')
model.features.7.squeeze.weight tensor(0.1207, device='cuda:0') tensor(0.2179, device='cuda:0')
model.features.7.squeeze.bias tensor(0.1484, device='cuda:0') tensor(0.0381, device='cuda:0')
model.features.7.expand1x1.weight tensor(0.1235, device='cuda:0') tensor(0.2279, device='cuda:0')
model.features.7.expand1x1.bias tensor(0.0450, device='cuda:0') tensor(0.0100, device='cuda:0')
model.features.7.expand3x3.weight tensor(0.0609, device='cuda:0') tensor(0.1628, device='cuda:0')
model.features.7.expand3x3.bias tensor(0.0132, device='cuda:0') tensor(0.0079, device='cuda:0')
model.features.9.squeeze.weight tensor(0.1093, device='cuda:0') tensor(0.2459, device='cuda:0')
model.features.9.squeeze.bias tensor(0.0646, device='cuda:0') tensor(0.0135, device='cuda:0')
model.features.9.expand1x1.weight tensor(0.0840, device='cuda:0') tensor(0.1860, device='cuda:0')
model.features.9.expand1x1.bias tensor(0.0177, device='cuda:0') tensor(0.0033, device='cuda:0')
model.features.9.expand3x3.weight tensor(0.0476, device='cuda:0') tensor(0.1393, device='cuda:0')
model.features.9.expand3x3.bias tensor(0.0058, device='cuda:0') tensor(0.0030, device='cuda:0')
model.features.10.squeeze.weight tensor(0.0872, device='cuda:0') tensor(0.1676, device='cuda:0')
model.features.10.squeeze.bias tensor(0.0484, device='cuda:0') tensor(0.0088, device='cuda:0')
model.features.10.expand1x1.weight tensor(0.0859, device='cuda:0') tensor(0.2145, device='cuda:0')
model.features.10.expand1x1.bias tensor(0.0160, device='cuda:0') tensor(0.0025, device='cuda:0')
model.features.10.expand3x3.weight tensor(0.0456, device='cuda:0') tensor(0.1429, device='cuda:0')
model.features.10.expand3x3.bias tensor(0.0070, device='cuda:0') tensor(0.0021, device='cuda:0')
model.features.11.squeeze.weight tensor(0.0786, device='cuda:0') tensor(0.2003, device='cuda:0')
model.features.11.squeeze.bias tensor(0.0422, device='cuda:0') tensor(0.0069, device='cuda:0')
model.features.11.expand1x1.weight tensor(0.0690, device='cuda:0') tensor(0.1400, device='cuda:0')
model.features.11.expand1x1.bias tensor(0.0138, device='cuda:0') tensor(0.0022, device='cuda:0')
model.features.11.expand3x3.weight tensor(0.0366, device='cuda:0') tensor(0.1517, device='cuda:0')
model.features.11.expand3x3.bias tensor(0.0109, device='cuda:0') tensor(0.0023, device='cuda:0')
model.features.12.squeeze.weight tensor(0.0729, device='cuda:0') tensor(0.1736, device='cuda:0')
model.features.12.squeeze.bias tensor(0.0814, device='cuda:0') tensor(0.0084, device='cuda:0')
model.features.12.expand1x1.weight tensor(0.0977, device='cuda:0') tensor(0.1385, device='cuda:0')
model.features.12.expand1x1.bias tensor(0.0102, device='cuda:0') tensor(0.0032, device='cuda:0')
model.features.12.expand3x3.weight tensor(0.0365, device='cuda:0') tensor(0.1312, device='cuda:0')
model.features.12.expand3x3.bias tensor(0.0038, device='cuda:0') tensor(0.0026, device='cuda:0')
model.classifier.1.weight tensor(0.0285, device='cuda:0') tensor(0.0865, device='cuda:0')
model.classifier.1.bias tensor(0.0362, device='cuda:0') tensor(0.0192, device='cuda:0')
ipdb> opt.lr # 檢視學習率
0.001
ipdb> opt.lr = 0.002 # 更改學習率
ipdb> for p in optimizer.param_groups: \
p['lr'] = opt.lr
ipdb> model.save() # 儲存模型
'checkpoints/squeezenet_20191004212249.pth'
ipdb> c # 繼續運作,直到第88行暫停
222it [16:38, 35.62s/it]> e:\workspace\python\pytorch\chapter6\main.py(88)train()
87 optimizer.zero_grad()
1--> 88 score = model(input)
89 loss = criterion(score, target)
ipdb> s # 進入model(input)内部,即model.__call__(input)
--Call--
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(537)__call__()
536
--> 537 def __call__(self, *input, **kwargs):
538 for hook in self._forward_pre_hooks.values():
ipdb> n # 下一步
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(538)__call__()
537 def __call__(self, *input, **kwargs):
--> 538 for hook in self._forward_pre_hooks.values():
539 result = hook(self, input)
ipdb> n # 下一步
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(544)__call__()
543 input = result
--> 544 if torch._C._get_tracing_state():
545 result = self._slow_forward(*input, **kwargs)
ipdb> n # 下一步
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(547)__call__()
546 else:
--> 547 result = self.forward(*input, **kwargs)
548 for hook in self._forward_hooks.values():
ipdb> s # 進入forward函數内容
--Call--
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\loss.py(914)forward()
913
--> 914 def forward(self, input, target):
915 return F.cross_entropy(input, target, weight=self.weight,
ipdb> input # 檢視input變量值
tensor([[4.5005, 2.0725],
[3.5933, 7.8643],
[2.9086, 3.4209],
[2.7740, 4.4332],
[6.0164, 2.3033],
[5.2261, 3.2189],
[2.6529, 2.0749],
[6.3259, 2.2383],
[3.0629, 3.4832],
[2.7008, 8.2818],
[5.5684, 2.1567],
[3.0689, 6.1022],
[3.4848, 5.3831],
[1.7920, 5.7709],
[6.5032, 2.8080],
[2.3071, 5.2417],
[3.7474, 5.0263],
[4.3682, 3.6707],
[2.2196, 6.9298],
[5.2201, 2.3034],
[6.4315, 1.4970],
[3.4684, 4.0371],
[3.9620, 1.7629],
[1.7069, 7.8898],
[3.0462, 1.6505],
[2.4081, 6.4456],
[2.1932, 7.4614],
[2.3405, 2.7603],
[1.9478, 8.4156],
[2.7935, 7.8331],
[1.8898, 3.8836],
[3.3008, 1.6832]], device='cuda:0', grad_fn=<AsStridedBackward>)
ipdb> input.data.mean() # 檢視input的均值和标準差
tensor(3.9630, device='cuda:0')
ipdb> input.data.std()
tensor(1.9513, device='cuda:0')
ipdb> u # 跳回上一層
> c:\programdata\anaconda3\lib\site-packages\torch\nn\modules\module.py(547)__call__()
546 else:
--> 547 result = self.forward(*input, **kwargs)
548 for hook in self._forward_hooks.values():
ipdb> u # 跳回上一層
> e:\workspace\python\pytorch\chapter6\main.py(88)train()
87 optimizer.zero_grad()
1--> 88 score = model(input)
89 loss = criterion(score, target)
ipdb> clear # 清除所有斷點
Clear all breaks? y
Deleted breakpoint 1 at e:\workspace\python\pytorch\chapter6\main.py:88
ipdb> c # 繼續運作,記得先删除"debug/debug.txt",否則很快又會進入調試模式
59it [06:21, 5.75it/s]loss: 0.24856307208538073
76it [06:24, 5.91it/s]
當我們想要進入debug模式,修改程式中某些參數值或者想分析程式時,就可以通過建立debug辨別檔案,此時程式會進入調試模式,調試完成之後删除這個檔案并在ipdb調試接口輸入c繼續運作程式。如果想退出程式,也可以使用這種方式,先建立debug辨別檔案,然後輸入quit在退出debug的同時退出程式。這種退出程式的方式,與使用Ctrl+C的方式相比更安全,因為這能保證資料加載的多程序程式也能正确地退出,并釋放記憶體、顯存等資源。
PyTorch和ipdb集合能完成很多其他架構所不能完成或很難完成的功能。根據筆者日常使用的總結,主要有以下幾個部分:
(1)通過debug暫停程式。當程式進入debug模式後,将不再執行PCU和GPU運算,但是記憶體和顯存及相應的堆棧空間不會釋放。
(2)通過debug分析程式,檢視每個層的輸出,檢視網絡的參數情況。通過u§、d(own)、s(tep)等指令,能夠進入指定的代碼,通過n(ext)可以單步執行,進而看到每一層的運算結果,便于分析網絡的數值分布等資訊。
(3)作為動态圖架構,PyTorch擁有Python動态語言解釋執行的優點,我們能夠在運作程式時,用過ipdb修改某些變量的值或屬性,這些修改能夠立即生效。例如可以在訓練開始不久根據損失函數調整學習率,不必重新開機程式。
(4)如果在IPython中通過%run魔法方法運作程式,那麼在程式異常退出時,可以使用%debug指令,直接進入debug模式,通過u§和d(own)跳到報錯的地方,檢視對應的變量,找出原因後修改相應的代碼即可。有時我們的模式訓練了好幾個小時,卻在将要儲存模式之前,因為一個小小的拼寫錯誤異常退出。此時,如果修改錯誤再重新運作程式又要花費好幾個小時,太浪費時間。是以最好的方法就是看利用%debug進入調試模式,在調試模式中直接運作model.save()儲存模型。在IPython中,%pdb魔術方法能夠使得程式出現問題後,不用手動輸入%debug而自動進入debug模式,建議使用。
PyTorch調用CuDNN報錯時,報錯資訊諸如CUDNN_STATUS_BAD_PARAM,從這些報錯内容很難得到有用的幫助資訊,最後先利用PCU運作代碼,此時一般會得到相對友好的報錯資訊,例如在ipdb中執行model.cpu()(input.cpu()),PyTorch底層的TH庫會給出相對比較詳細的資訊。
常見的錯誤主要有以下幾種:
- 類型不比對問題。例如CrossEntropyLoss的輸入target應該是一個LongTensor,而很多人輸入FloatTensor。
- 部分資料忘記從CPU轉移到GPU。例如,當model存放于GPU時,輸入input也需要轉移到GPU才能輸入到model中。還有可能就是把多個model存放于一個list對象,而在執行model.cuda()時,這個list中的對象是不會被轉移到CUDA上的,正确的用法是用ModuleList代替。
- Tensor形狀不比對。此類問題一般是輸入資料形狀不對,或是網絡結構設計有問題,一般通過u§跳到指定代碼,檢視輸入和模型參數的形狀即可得知。
此外,可能還會經常遇到程式正常運作、沒有報錯,但是模型無法收斂的問題。例如對于二分類問題,交叉熵損失一直徘徊在0.69附近(ln2),或者是數值出現溢出等問題,此時可以進入debug模式,用單步執行檢視,每一層輸出的均值和方差,觀察從哪一層的輸出開始出現數值異常。還要檢視每個參數梯度的均值和方差,檢視是否出現梯度消失或者梯度爆炸等問題。一般來說,通過再激活函數之前增加BatchNorm層、合理的參數初始化、使用Adam優化器、學習率設為0.001,基本就能確定模型在一定程度收斂。
本章帶領讀者從頭實作了一個Kaggle上的經典競賽,重點講解了如何合理地組合安排程式,同時介紹了一些在PyTorch中調試的技巧。