先看一下定義(沒上官網,在外,熱點官網上不去)
參數:
- input:我也不知道為何上面的函數參數中沒有,預設第一個參數是input,是一個Tensor,裡面的元素值越大,那麼這個元素就越有可能被取到,取出的不是這個元素值,而是這個元素對應的下标
- num_samples:整數,表示要取多少個下标
- replacement:True表示有放回的采樣,False表示不放回,取一個少一個
- generator暫時沒去了解是什麼,懂的評論教我一手,我隻知道類型必須是torch.Generator,好像必須要CUDA才能用。。
實驗
import torch
# import numpy as np
freq = torch.Tensor([0.1,0.2,0.3,0.9])
K = 2
# generate = torch.Generator
res = torch.multinomial(input = freq,num_samples=K,replacement=True)
print(res)
結果
基本上每次都是:
tensor([3, 3])
但也有别的情況