天天看點

Pytorch的modle.train,model.eval,with torch.no_grad的個人了解

1. 最近在學習pytorch過程中遇到了幾個問題,不了解為什麼在訓練和測試函數中model.eval(),和model.train()的差別,經查閱後做如下整理

一般情況下,我們訓練過程如下:

  1. 拿到資料後進行訓練,在訓練過程中,使用
model.train():告訴我們的網絡,這個階段是用來訓練的,可以更新參數。
  1. 訓練完成後進行預測,在預測過程中,使用
model.eval() : 告訴我們的網絡,這個階段是用來測試的,于是模型的參數在該階段不進行更新。

2. 但是為什麼在eval()階段會使用with torch.no_grad()?

查閱相關資料:​​傳送門​​

with torch.no_grad - disables tracking of gradients in autograd.

model.eval() changes the forward() behaviour of the module it is called upon

       eg, it disables dropout and has batch norm use the entire population statistics