天天看點

【Pytorch】detach, requires_grad和volatile

在跑CIN的代碼時,将batch_size從10一路降到2,依然每執行sample就爆顯存.請教師兄後,發現問題出在這一句上:

【Pytorch】detach, requires_grad和volatile

在進行sample的時候,不止儲存之前的變量fake,而且還儲存了fake前所有的梯度.計算圖進行累積,那樣不管有多大顯存都是放不下的.

之後,在self.G(real_x, target_c)[0]後面加上了.detach(),代碼就可以順利運作了.

查閱pytorch的官方文檔,上面是這麼說的:

【Pytorch】detach, requires_grad和volatile

簡單來說,就是建立一個新的tensor,将其從目前的計算圖中分離出來.新的tensor與之前的共享data,但是不具有梯度.在任意一個tensor上進行原地操作都會報錯(what?)

進行驗證發現,v_c是具有梯度的,但是進行detach之後建立的新變量v_c_detached是不具有梯度的.

【Pytorch】detach, requires_grad和volatile

對v_c_detached進行修改,v_c的data值也會改變.說明他們是共享同一塊顯存的.

【Pytorch】detach, requires_grad和volatile

在pytorch中,autograd是由計算圖實作的.Variable是autograd的核心資料結構,其構成分為三部分: data(tensor), grad(也是Variable), grad_fn(得到這一節點的直接操作).對于requires_grad為false的節點,是不具有grad的.

【Pytorch】detach, requires_grad和volatile

使用者自己建立的節點是leaf_node(如圖中的abc三個節點),不依賴于其他變量,對于leaf_node不能進行in_place操作.根節點是計算圖的最終目标(如圖y),通過鍊式法則可以計算出所有節點相對于根節點的梯度值.這一過程通過調用root.backward()就可以實作.

是以,detach所做的就是,重新聲明一個變量,指向原變量的存放位置,但是requires_grad為false.更深入一點的了解是,計算圖從detach過的變量這裡就斷了, 它變成了一個leaf_node.即使之後重新将它的requires_node置為true,它也不會具有梯度.

【Pytorch】detach, requires_grad和volatile

另一方面,在調用完backward函數之後,非leaf_node的梯度計算完會立刻被清空.這也是為什麼在執行backward之前顯存占用很大,執行完之後顯存占用立刻下降很多的原因.當然,這其中也包含了一些中間結果被存在buffer中,調用結束後也會被釋放.

至于另一個參數volatile,如果一個變量的volatile=true,它可以将所有依賴于它的節點全部設為volatile=true,優先級高于requires_grad=true.這樣的節點不會進行求導,即使requires_grad為真,也無法進行反向傳播.在inference中如果采用這種設定,可以實作一定程度的速度提升,并且節約大概一半顯存.

作者:nowherespyfly

連結:https://www.jianshu.com/p/f1bd4ff84926

來源:簡書

簡書著作權歸作者所有,任何形式的轉載都請聯系作者獲得授權并注明出處。

繼續閱讀