殘差網絡效果
卷積神經網絡CNN的發展曆史如圖所示:

從起初AlexNet的的8層網絡,到ResNet的152層網絡,層數逐漸增加。當網絡層數增加到一定程度之後,錯誤率反而上升,其原因是層數太多梯度下降變得越發困難。而ResNet解決了這一問題。
目前ResNet是應用最廣的圖像相關深度學習網絡,圖像分類,目标檢測,圖檔分割都使用該網絡結構作為基礎,另外,一些遷移學習也使用ResNet訓練好的模型來提取圖像特征。
殘差網絡原理
首先,來看看比較官方的殘差網絡原理說明:
“若将輸入設為X,将某一有參網絡層設為H,那麼以X為輸入的此層的輸出将為H(X)。一般的CNN網絡如Alexnet/VGG等會直接通過訓練學習出參數函數H的表達,進而直接學習X -> H(X) 。而殘差學習則是緻力于使用多個有參網絡層來學習輸入、輸出之間的殘差即H(X) - X即學習X -> (H(X) - X) + X。其中X這一部分為直接的identity mapping,而H(X) - X則為有參網絡層要學習的輸入輸出間殘差。”
第一次看到上述文字,我似乎明白了,但了解又不一定正确。在沒看到代碼之前,對VGG/ResNet的結構原理沒什麼感覺,幾乎就是背下來哪個效果比較好,大概用了什麼技術。後來看到了Pytorch中ResNet的代碼,原來簡單到"五分鐘包會"的程度。用自然語言描述程式果然是把簡單的問題搞複雜了。
解讀核心程式
直接看代碼,不學習TensorFlow的複雜結構,也不使用生澀的公式語言,而用順序結構的Pytorch作為通往深度學習的捷徑。下面來解讀Pytorch官方版的ResNet實作。完整代碼見;
https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
Torchvision是Torch的圖像工具包,上述代碼包含在Torchvision之中,同一目錄下還有alexnet,googlenet,vgg的實作。ResNet代碼共300多行,其中核心代碼不到200行,實作了三個主要類:ResNet、BasicBlock、Bottleneck。
1.殘差是什麼,如何實作?
BasicBlock類中計算了殘差,該類繼承了nn.Module(Pytorch基本用法請見參考部分),實作了兩個函數:用于建立網絡結構的init和實作前向算法的forward。如下所示:
其中x是輸入,out是輸出,從程式代碼可以看出,與基本流程不同的是,它加入了indentity,而indentity就是輸入x本身(也支援下采樣),也就是說,在經過多層轉換得到的out上加輸入資料x,即上面所說的 H(X)+X。如果設輸出Y=H(X)+X,則有H(X)=Y-X,建構網絡H(X)用于求取輸出Y與輸入X的差異,即殘差。而之前的網絡都是直接求從X到Y的方法。
2.BasicBlock和Bottleneck
BasicBlock類用于建構網絡中的子網絡結構(後稱block),子網絡中包含兩個卷積層和殘差處理。一個ResNet包含多個BasicBlock子網絡。是以相對于傳統網絡,ResNet常被描繪成下圖的結構,右側的弧線是“+X”的操作。
Bottleneck是BasicBlock的更新版,其功能也是構造子網絡,resnet18和resnet34中使用了BasicBlock,而resnet50、resnet101、resnet152使用了Bottlenect構造網絡。
Bottleneck和BasicBlock網絡結構對比如下圖所示:
左圖中的BasicBlock包含兩個3x3的卷積層,右圖的Bottleneck包括了三個卷積層,第一個1x1的卷積層用于降維,第二個3x3層用于處理,第三個1x1層用于升維,這樣減少了計算量。
3.主要ResNet類
ResNet中最常用的是ResNet50,它兼顧了準确性和運算量。下面以RenNet50作為示例,分析建構ResNet的具體方法。
在調用_resnet建立網絡時,第二個參數指定使用Bottleneck類建構子網絡,第三個參數指定了每一層layer由幾個子網絡block構成。
下圖是ResNet的初始化部分init中,用于建構網絡結構的代碼(建議在github檢視完整代碼)。
可以看到程式用函數_make_layer建立了四個層,以resnet50為例,各個層中block的個數依次是3,4,6,3個,而每個block(Bottleneck)中又包含三個卷積層,(3+4+6+3)*3共48個卷積層,外加第141行建立的另一卷積層和第154行建立的一個全連接配接層,總共50個主要層,這也是resnet50中50的含義。
除此以外,上述torchvision程式還提供了下載下傳預測訓練的模型參數,通過設定pretrain=True/False選擇是否使用預訓練的模型。