如何将pytorch的整数张量转换为布尔的张量?

时间:2018-11-30 17:45:21

标签: python casting pytorch

我想将一个整数张量转换为一个布尔张量。

具体来说,我希望能够将tensor([0,10,0,16])转换为tensor([0,1,0,1])

的函数

仅使用tf.cast(x,tf.bool),这在Tensorflow中是微不足道的。

我希望强制类型转换将所有大于0的整数更改为1,并将所有等于0的整数更改为0。在大多数语言中,这相当于!!

由于pytorch似乎没有强制转换的专用布尔类型,因此最佳方法是什么?

编辑:我正在寻找一种与循环遍历每个元素相对的矢量化解决方案。

5 个答案:

答案 0 :(得分:2)

您可以使用下面的代码片段中所示的比较。

 a = tensor([0,10,0,16])
 result = (a == 0)

会给出

 tensor([1, 0, 1, 0], dtype=torch.uint8)

答案 1 :(得分:2)

另一种选择是简单地做:

temp = torch.tensor([0,10,0,16])
temp.bool()
#Returns
tensor([False,  True, False,  True])

答案 2 :(得分:1)

您要寻找的是为给定的整数张量生成一个布尔掩码。为此,您可以简单地检查条件:值是否大于0,这将给出所需的结果。

# input tensor
In [76]: t   
Out[76]: tensor([ 0, 10,  0, 16])

# generate the needed boolean mask
In [78]: t > 0      
Out[78]: tensor([0, 1, 0, 1], dtype=torch.uint8)

# sanity check
In [93]: mask = t > 0      

In [94]: mask.type()      
Out[94]: 'torch.ByteTensor'

答案 3 :(得分:1)

将布尔值转换为数字值:

a = torch.tensor([0,4,0,0,5,0.12,0.34,0,0])
print(a.gt(0)) # output in boolean dtype
# output: tensor([False,  True, False, False,  True,  True,  True, False, False])

print(a.gt(0).to(torch.float32)) # output in float32 dtype
# output: tensor([0., 1., 0., 0., 1., 1., 1., 0., 0.])

答案 4 :(得分:0)

PyTorch 的 <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script> <input type="text" placeholder="1-5,9"> <table border="border"> <thead> <tr> <th class="text-center">data</th> <th class="text-center">data</th> <th class="text-center">data</th> </tr> </thead> <tbody> </tbody> </table> 方法有方便的 data-type named aliases。您可以简单地调用 to(dtype):

bool
>>> t.bool()
tensor([False,  True, False,  True])