pytorch,样品给出批处理logits

时间:2020-08-20 12:55:34

标签: pytorch

给出类似的登录信息

# each row is a record of data
logits = np.array([ [0.1, 0.3, 0.5], [0.3, 0.1, 0.5], [0.1, 0.3, 0.0] ])

如何使用Pytorch对每一行的logit索引进行采样?当前的分发API不支持此类功能。

例如,我想要的是

distribution = Categorical(logits=logits)
labels = distribution.sample(dim=1)

0 个答案:

没有答案