torch.sactter_()的用法簡析
明天就要開周會了,今天我的ppt和内容都還沒有做,集中注意力學習了一下
scatter_()
函數,弄清了二維和獨熱編碼是的工作原理。
如果你經常看到類似下面的東西:
y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
經常性的感覺到困惑,檢視官方文檔,不知道内在如何操作,檢視部落格,一大堆照抄的内容,标點符号都沒有改。今天我可以講清楚 二維
torch tensor
結構的内容。
官方函數: 将src中的所有值按照index确定的索引寫入本tensor中。其中索引是根據給定的dimension,dim按照gather()描述的規則來确定。
注意,index的值必須是在_0_到_(self.size(dim)-1)_之間,
scatter_(input, dim, index, src) → Tensor
參數: - input (Tensor)-源tensor - dim (int)-索引的軸向 - index (LongTensor)-散射元素的索引指數 - src (Tensor or float)-散射的源元素。
部落格上,大家一上來,就列出三維torch scatter 的形式化表示,這個有點誇張!二維先搞定,三維慢慢就會明白,雖然我現在也不懂三維,但是平時用的少
這裡才是正經内容
借用官方例程
>>> x = torch.rand(2, 5)
>>> x
0.4319 0.6500 0.4080 0.8760 0.2355
0.2609 0.4711 0.8486 0.8573 0.1029
[torch.FloatTensor of size 2x5]
>>> torch.zeros(3, 5).scatter_(0, torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
0.4319 0.4711 0.8486 0.8760 0.2355
0.0000 0.6500 0.0000 0.8573 0.0000
0.2609 0.0000 0.4080 0.0000 0.1029
[torch.FloatTensor of size 3x5]
>>> z = torch.zeros(2, 4).scatter_(1, torch.LongTensor([[2], [3]]), 1.23)
>>> z
0.0000 0.0000 1.2300 0.0000
0.0000 0.0000 0.0000 1.2300
[torch.FloatTensor of size 2x4]
看圖說話:
dim=0,可以了解為,固定列。圖檔從上往下看,順着五個小箭頭看
下圖說明
index = torch.LongTensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]])
src = x = torch.rand(2, 5)
- 将 數組
的每一行分别對應 數組index
的每一行,如下圖所示:x
看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析 - 固定每一列,
的元素值,作為結果行号,index
每一個元素下面對應的index
x
的元素值,為值,傳遞到結果數組中。
第一行
第二行看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析 看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析 - 兩行内容合并(每一行的結果相當于一個數組,兩個數組相加)
看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析
結果的形式化依據,可以參考照片部落格,位址
獨熱編碼
放兩張圖。
-
dim=1
-
dim=0
看圖說話,如何搞懂torch.sactter_(),用于散列數組及獨熱編碼One-hottorch.sactter_()的用法簡析