我有一个张量pred
,其张量.size()
为[4, 53, 161]
。我有另一个张量mask
,其张量.size()
为[4, 53]
。
mask
仅是0和1。我想做的是获取pred
的值,其中mask
的值为1
。您会注意到pred
的维度比mask
大一倍。我该怎么做?
答案 0 :(得分:1)
mask = mask.unsqueeze(2)
new_pred = pred * mask
这将增加额外的尺寸。现在是[4, 53, 1]
。休息会照顾广播。 (如果您进行一些操作)
假设您有一个形状张量[4, 53, 164]
,现在您想将其简化为[4, 53]
,则可以应用像这样的new_pred.mean(2)
算术运算。