天天看點

看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析

torch.sactter_()的用法簡析

明天就要開周會了,今天我的ppt和内容都還沒有做,集中注意力學習了一下

scatter_()

函數,弄清了二維和獨熱編碼是的工作原理。

如果你經常看到類似下面的東西:

y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
           

經常性的感覺到困惑,檢視官方文檔,不知道内在如何操作,檢視部落格,一大堆照抄的内容,标點符号都沒有改。今天我可以講清楚 二維

torch tensor

結構的内容。

官方函數: 将src中的所有值按照index确定的索引寫入本tensor中。其中索引是根據給定的dimension,dim按照gather()描述的規則來确定。

注意,index的值必須是在_0_到_(self.size(dim)-1)_之間,

scatter_(input, dim, index, src) → Tensor
           

參數: - input (Tensor)-源tensor - dim (int)-索引的軸向 - index (LongTensor)-散射元素的索引指數 - src (Tensor or float)-散射的源元素。

部落格上,大家一上來,就列出三維torch scatter 的形式化表示,這個有點誇張!二維先搞定,三維慢慢就會明白,雖然我現在也不懂三維,但是平時用的少

這裡才是正經内容

借用官方例程

>>> x = torch.rand(2, 5)
>>> x

 0.4319  0.6500  0.4080  0.8760  0.2355
 0.2609  0.4711  0.8486  0.8573  0.1029
[torch.FloatTensor of size 2x5]

>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)

 0.4319  0.4711  0.8486  0.8760  0.2355
 0.0000  0.6500  0.0000  0.8573  0.0000
 0.2609  0.0000  0.4080  0.0000  0.1029
[torch.FloatTensor of size 3x5]

>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23)
>>> z

 0.0000  0.0000  1.2300  0.0000
 0.0000  0.0000  0.0000  1.2300
[torch.FloatTensor of size 2x4]
           

看圖說話:

dim=0,可以了解為,固定列。圖檔從上往下看,順着五個小箭頭看

下圖說明

index = torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
src = x = torch.rand(2, 5)
           
  • 将 數組

    index

    的每一行分别對應 數組

    x

    的每一行,如下圖所示:
    看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析
  • 固定每一列,

    index

    的元素值,作為結果行号,

    index

    每一個元素下面對應的

    x

    的元素值,為值,傳遞到結果數組中。

    第一行

    看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析
    第二行
    看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析
  • 兩行内容合并(每一行的結果相當于一個數組,兩個數組相加)
    看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析

結果的形式化依據,可以參考照片部落格,位址

獨熱編碼

放兩張圖。

  • dim=1

看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析
  • dim=0

    看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析