天天看點

【幹貨】深度學習實驗流程及 PyTorch 提供的解決方案

常見的 Research workflow

某一天, 你坐在實驗室的椅子上, 突然:

1 你腦子裡迸發出一個 idea

2 你看了關于某一 theory 的文章, 想試試: 要是把 xx 也加進去會怎麼樣

3 你老闆突然給你一張紙, 然後說: 那個誰, 來把這個東西實作一下

于是, 你設計了實驗流程, 并為這一 idea 挑選了合适的資料集和運作環境, 然後你廢寝忘食的實作了模型, 經過長時間的訓練和測試, 你發現:

1 這 idea 不 work --> 那算了 or 再調調

2 這 idea 很 work --> 可以寫 paper 了

我們可以把上述流程用下圖表示:

【幹貨】深度學習實驗流程及 PyTorch 提供的解決方案

實際上, 常見的流程由下面幾項組成起來:

【幹貨】深度學習實驗流程及 PyTorch 提供的解決方案

1   一旦標明了資料集, 你就要寫一些函數去 load 資料集, 然後 pre-process 資料集, normalize 資料集, 可以說這是一個實驗中占比重最多的部分, 因為:

2   每個資料集的格式都不太一樣

3   預處理和正則化的方式也不盡相同

4   需要一個快速的 dataloader 來 feed data, 越快越好

5   然後, 你就要實作自己的模型, 如果你是 CV 方向的你可能想實作一個 ResNet, 如果你是 NLP 相關的你可能想實作一個 Seq2Seq

6   接下來, 你需要實作訓練步驟, 分 batch, 循環 epoch

7   在若幹輪的訓練後, 總要 checkpoint 一下, 才是最安全的

8   你還需要建構一些 baseline, 以驗證自己 idea 的有效性

9   如果你實作的是神經網絡模型, 當然離不開 GPU 的支援

10 很多深度學習架構提供了常見的損失函數, 但大部分時間, 損失函數都要和具體任務結合起來, 然後重新實作

11 使用優化方法, 優化建構的模型, 動态調整學習率

Pytorch 給出的解決方案

對于加載資料, Pytorch 提出了多種解決辦法

1 Pytorch 是一個 Python 包, 而不是某些大型 C++ 庫的 Python 接口, 是以, 對于資料集本身提供 Python API 的, Pytorch 可以直接調用, 不必特殊處理.

2 Pytorch 內建了常用資料集的 data loader

3 雖然以上措施已經能涵蓋大部分資料集了, 但 Pytorch 還開展了兩個項目: vision, 和 text, 見下圖灰色背景部分. 這兩個項目, 采用衆包機制, 收集了大量的 dataloader, pre-process 以及 normalize, 分别對應于圖像和文本資訊.

【幹貨】深度學習實驗流程及 PyTorch 提供的解決方案

4 如果你要自定義資料集,也隻需要繼承 torch.utils.data.dataset

對于構模組化型, Pytorch 也提供了三種方案

1 衆包的模型: torch.utils.model_zoo , 你可以使用這個工具, 加載大家共享出來的模型

2 使用 torch.nn.Sequential 子產品快速建構

內建 torch.nn.Module 深度定制

對于訓練過程的 Pytorch 實作

你當然可以自己實作資料的 batch, shuffer 等, 但 Pytorch 建議用類 torch.utils.data.DataLoader 加載資料,并對資料進行采樣,生成<code>batch</code>疊代器。

對于儲存和加載模型 Pytorch 提供兩種方案

儲存和加載整個網絡

儲存和加載網絡中的參數

對于 GPU 支援

你可以直接調用 Tensor 的. cuda() 直接将 Tensor 的資料遷移到 GPU 的顯存上, 當然, 你也可以用. cpu() 随時将資料移回記憶體

對于 Loss 函數, 以及自定義 Loss

在 Pytorch 的包 torch.nn 裡, 不僅包含常用且經典的 Loss 函數, 還會實時跟進新的 Loss 包括: CosineEmbeddingLoss, TripletMarginLoss 等.

如果你的 idea 非常新穎, Pytorch 提供了三種自定義 Loss 的方式

繼承 torch.nn.module

然後

這樣做, 你能夠用 torch.nn.functional 裡優化過的各種函數來組成你的 Loss

繼承 torch.autograd.Function

這樣做,你能夠用常用的 numpy 和 scipy 函數來組成你的 Loss

寫一個 Pytorch 的 C 擴充

這裡就不細講了,未來會有内容專門介紹這一部分。

對于優化算法以及調節學習率

Pytorch 內建了常見的優化算法, 包括 SGD, Adam, SparseAdam, AdagradRMSprop, Rprop 等等.

torch.optim.lr_scheduler 提供了多種方式來基于 epoch 疊代次數調節學習率 torch.optim.lr_scheduler.ReduceLROnPlateau 還能夠基于實時的學習結果, 動态調整學習率.

原文釋出時間為:2018-02-11

本文來自雲栖社群合作夥伴新智元,了解相關資訊可以關注“AI_era”微信公衆号