天天看點

了解pytorch的 backward(), 究竟幹了什麼

很多的文字看着頭疼, 我們來看看下面代碼,一看就懂

x=th.ones(1,requires_grad=True)
z=x+1
y=z**3
t=y+2
m=t*6


print(m.grad)
print(t.grad)
print(y.grad)
print(z.grad)
print(x.grad)
           

結果是:

None
None
None
None
None
           

第二段

x=th.ones(1,requires_grad=True)
z=x+1
y=z**3
t=y+2
m=t*6

m.backward()
print(m.grad)
print(t.grad)
print(y.grad)
print(z.grad)
print(x.grad)
           

結果是

None
None
None
None
tensor([72.])

           

為啥呢? 為啥第一次,導數都是None?

為啥第二次,隻有最後一個x 列印出來資料了呢?

很簡單: 如果你想列印出來grad ,要滿足兩個條件⭐

  1. 隻有聲明了“ requires_grad=True ” 的元素,才可以調用 .grad!列印出資料。 其他的,不管你是不是 x 的複合函數,都不能列印 grad
  2. 要想列印出 grad ,必須先運作 backward , 這是一個求導的動作,沒有這個backward ,倒數就是None

現在就很清楚了,由于 m ,t , y, z ,都沒有聲明自己的 require_grad是 True。 是以第一個程式中,既因為前面沒有backward,又因為自己沒有require,是以導數永遠就是0

對于x ,第一個程式沒有 backward ,是以導數是None; 第二個程式,滿足了上面兩個條件,就能成功列印了!

再看一個例子

x=th.ones(1,requires_grad=True)
z=x+1
y=z**3
t=y+2
m=t*6

t.backward()
print(m.grad)
print(t.grad)
print(y.grad)
print(z.grad)
print(x.grad)
           

結果是:

None
None
None
None
tensor([12.])
           

這裡,注意, 我們把 從m 的backward 換成了 從t 的backward,

也就意味着, 反向求導的位置不同。 我們口算一下 ,

發現 t對x 的導數就是 12, m對x的導數就是72

你明白了吧 , 從不同位置進行 backward ,結果也是不同的。

這就是簡單的對于 backward的認識,謝謝閱讀。

繼續閱讀