如何使用PyTorch基于蒙版选择多维张量的所有值?

时间:2020-03-28 17:52:30

标签: python pytorch tensor

我有一个张量pred,其张量.size()[4, 53, 161]。我有另一个张量mask,其张量.size()[4, 53]

mask仅是0和1。我想做的是获取pred的值,其中mask的值为1。您会注意到pred的维度比mask大一倍。我该怎么做?

1 个答案:

答案 0 :(得分:1)

mask = mask.unsqueeze(2)
new_pred = pred * mask

这将增加额外的尺寸。现在是[4, 53, 1]。休息会照顾广播。 (如果您进行一些操作)

假设您有一个形状张量[4, 53, 164],现在您想将其简化为[4, 53],则可以应用像这样的new_pred.mean(2)算术运算。