给出类似的登录信息
# each row is a record of data
logits = np.array([ [0.1, 0.3, 0.5], [0.3, 0.1, 0.5], [0.1, 0.3, 0.0] ])
如何使用Pytorch对每一行的logit索引进行采样?当前的分发API不支持此类功能。
例如,我想要的是
distribution = Categorical(logits=logits)
labels = distribution.sample(dim=1)