我想将一个整数张量转换为一个布尔张量。
具体来说,我希望能够将tensor([0,10,0,16])
转换为tensor([0,1,0,1])
仅使用tf.cast(x,tf.bool)
,这在Tensorflow中是微不足道的。
我希望强制类型转换将所有大于0的整数更改为1,并将所有等于0的整数更改为0。在大多数语言中,这相当于!!
。
由于pytorch似乎没有强制转换的专用布尔类型,因此最佳方法是什么?
编辑:我正在寻找一种与循环遍历每个元素相对的矢量化解决方案。
答案 0 :(得分:2)
您可以使用下面的代码片段中所示的比较。
a = tensor([0,10,0,16])
result = (a == 0)
会给出
tensor([1, 0, 1, 0], dtype=torch.uint8)
答案 1 :(得分:2)
另一种选择是简单地做:
temp = torch.tensor([0,10,0,16])
temp.bool()
#Returns
tensor([False, True, False, True])
答案 2 :(得分:1)
您要寻找的是为给定的整数张量生成一个布尔掩码。为此,您可以简单地检查条件:值是否大于0,这将给出所需的结果。
# input tensor
In [76]: t
Out[76]: tensor([ 0, 10, 0, 16])
# generate the needed boolean mask
In [78]: t > 0
Out[78]: tensor([0, 1, 0, 1], dtype=torch.uint8)
# sanity check
In [93]: mask = t > 0
In [94]: mask.type()
Out[94]: 'torch.ByteTensor'
答案 3 :(得分:1)
将布尔值转换为数字值:
a = torch.tensor([0,4,0,0,5,0.12,0.34,0,0]) print(a.gt(0)) # output in boolean dtype # output: tensor([False, True, False, False, True, True, True, False, False]) print(a.gt(0).to(torch.float32)) # output in float32 dtype # output: tensor([0., 1., 0., 0., 1., 1., 1., 0., 0.])
答案 4 :(得分:0)
PyTorch 的 <script src="https://cdnjs.cloudflare.com/ajax/libs/jquery/3.3.1/jquery.min.js"></script>
<input type="text" placeholder="1-5,9">
<table border="border">
<thead>
<tr>
<th class="text-center">data</th>
<th class="text-center">data</th>
<th class="text-center">data</th>
</tr>
</thead>
<tbody>
</tbody>
</table>
方法有方便的 data-type named aliases。您可以简单地调用 to(dtype)
:
bool
>>> t.bool()
tensor([False, True, False, True])