
Notes on selected PyTorch functions

Categorical.log_prob()

log_prob(value) returns the natural logarithm of the probability that the distribution assigns to value, for example the log-probability of a sampled action. Example:

import torch
from torch.distributions import Categorical
import torch.nn.functional as F

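# Random values stand in for the logits of 5 actions
action_logits = torch.rand(5)
# Softmax turns the logits into a probability vector that sums to 1
action_probs = F.softmax(action_logits, dim=-1)
print(action_probs)
# The first positional argument of Categorical is probs
dist = Categorical(action_probs)
# Draw one action index according to action_probs
action = dist.sample()
print(action)
# Both expressions give the log-probability of the sampled action
print(dist.log_prob(action), torch.log(action_probs[action]))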
           

Output:

tensor([0.1419, 0.3035, 0.1763, 0.1427, 0.2355])
tensor(2)
tensor(-1.7358) tensor(-1.7358)
           

That is, $\log_e(0.1763) \approx -1.7358$, where 0.1763 is the probability of the sampled action (index 2).
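
The same check can be sketched for a batch of distributions. The snippet below is a minimal illustration, not part of the original notes: the batch of 3 rows, the fixed seed, and the variable names `actions`, `log_probs` and `manual` are assumptions made for the example. It builds the distribution with the `logits=` keyword instead of applying softmax first, then compares `log_prob` against an explicit log-softmax lookup.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical

torch.manual_seed(0)  # fixed seed so repeated runs sample the same actions

# Hypothetical batch: 3 rows, each with scores for 5 actions
action_logits = torch.rand(3, 5)

# With logits=, Categorical normalizes the scores internally
dist = Categorical(logits=action_logits)
actions = dist.sample()             # shape (3,), one action index per row

log_probs = dist.log_prob(actions)  # shape (3,), one log-probability per row

# Manual equivalent: log-softmax, then pick the entry of each sampled action
manual = (
    F.log_softmax(action_logits, dim=-1)
    .gather(-1, actions.unsqueeze(-1))
    .squeeze(-1)
)

print(log_probs)
print(manual)
print(torch.allclose(log_probs, manual))
```

The last line should print True: for each row, `log_prob` returns the log-softmax entry at the sampled index. Passing logits directly avoids a separate softmax followed by a log.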