PyTorch [如果x> 0.5则为1,否则对于x在输出中为0]带张量

时间:2019-09-19 02:18:36

标签: python pytorch torchvision

我在Symoid函数中有一个列表输出,作为PyTorch中的张量

例如

output (type) = torch.Size([4]) tensor([0.4481, 0.4014, 0.5820, 0.2877], device='cuda:0',

在进行二进制分类时,我想将所有值都从0.5降到0,将大于0.5降到1。

传统上,您可以使用NumPy数组使用列表迭代器:

output_prediction = [1 if x > 0.5 else 0 for x in outputs ]

这可以工作,但是我稍后必须将output_prediction转换回张量以供使用

torch.sum(ouput_prediction == labels.data)

labels.data是标签的二进制张量。

是否可以使用带有张量的列表迭代器?

2 个答案:

答案 0 :(得分:4)

prob = torch.tensor([0.3,0.4,0.6,0.7])

out = (prob>0.5).float()
# tensor([0.,0.,1.,1.])

说明:在pytorch中,您可以直接使用prob>0.5获得torch.bool类型的张量。然后,您可以通过.float()转换为浮点类型。

答案 1 :(得分:0)

为什么不考虑使用 loop less解决方案?也许像下面这样就足够了:

In [34]: output = torch.tensor([0.4481, 0.4014, 0.5820, 0.2877]) 

# subtract off the threshold value (0.5), create a boolean mask, 
# and then cast the resultant tensor to an `int` type
In [35]: result = torch.as_tensor((output - 0.5) > 0, dtype=torch.int32) 

In [36]: result        
Out[36]: tensor([0, 0, 1, 0], dtype=torch.int32)