目标和输入必须具有相同数量的元素

时间:2018-06-02 07:20:12

标签: python pytorch

我的输入标签尺寸为torch.size([30, 2, 96, 96, 96])

我的标签尺寸是torch.size([30, 96, 96, 96]),我正在将它们送到我的损失功能,如下所示:

loss = F.binary_cross_entropy(F.sigmoid(output),labels,torch.FloatTensor(CLASS_WEIGHTS).cuda())

当我运行时,我得到了

Value error:Target and input must have the same number of elements.target nelement(26542080) != input nelement(53084160)

我在这里有点困惑。我得到的输入值是目标值的两倍,因为它将[30,96,96]乘以类的数量,但我不确定为什么会这样,以及如何纠正它。任何建议都会有所帮助,提前谢谢。

1 个答案:

答案 0 :(得分:0)

与获取目标标签值的torch.nn.CrossEntropyLoss图层不同(即如果input具有(30, C, 96, 96, 96)C个类的target(30, 96, 96, 96)必须是input),torch.nn.functional.binary_cross_entropy()需要targettarget具有相同的形状(即(30, C, 96, 96, 96)形状torch.nn.CrossEntropyLoss),所以它需要目标标签的一个热门表示。

除非您选择def to_one_hot(x, C=2, tensor_class=torch.FloatTensor): """ One-hot a batched tensor of shape (B, ...) into (B, C, ...) """ x_one_hot = tensor_class(x.size(0), C, *x.shape[1:]).zero_() x_one_hot = x_one_hot.scatter_(1, x.unsqueeze(1), 1) return x_one_hot # Demonstration: num_classes = 2 labels = torch.LongTensor(30, 96, 96, 96).random_(0, num_classes) one_hot_labels = to_one_hot(labels, C=num_classes) print(one_hot_labels.shape) # > torch.Size([30, 2, 96, 96, 96]) ,否则您可以采用多种方法来加热目标广告(例如,请参阅此thread)。个人解决方案:

TCP/3306 FROM 10.0.1.0/24