天天看點

torch.multinomial()

先看一下定義(沒上官網,在外,熱點官網上不去)

參數:

  • input:我也不知道為何上面的函數參數中沒有,預設第一個參數是input,是一個Tensor,裡面的元素值越大,那麼這個元素就越有可能被取到,取出的不是這個元素值,而是這個元素對應的下标
  • num_samples:整數,表示要取多少個下标
  • replacement:True表示有放回的采樣,False表示不放回,取一個少一個
  • generator暫時沒去了解是什麼,懂的評論教我一手,我隻知道類型必須是torch.Generator,好像必須要CUDA才能用。。

實驗

import torch
# import numpy as np
freq = torch.Tensor([0.1,0.2,0.3,0.9])
K = 2
# generate = torch.Generator
res = torch.multinomial(input = freq,num_samples=K,replacement=True)
print(res)
           

結果

基本上每次都是:
tensor([3, 3])
但也有别的情況