如何摆脱Pytorch张量中充满零的每一列?

时间:2019-11-27 18:12:50

标签: python pytorch tensor torchtext

我有一个如下所示的火炬张量A

A = 
tensor([[  4,   3,   3,  ...,   0,   0,   0],
        [ 13,   4,  13,  ...,   0,   0,   0],
        [707, 707,   4,  ...,   0,   0,   0],
        ...,
        [  7,   7,   7,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        [195, 195, 195,  ...,   0,   0,   0]], dtype=torch.int32)

我想:

  • 标识所有条目均等于0的所有列
  • 仅删除所有条目都等于0的列

我可以想象这样做:

zero_list = []
for j in range(A.size()[1]):
    if torch.sum(A[:,j]) == 0:
         zero_list = zero_list.append(j)

标识元素只有0的列 但我不确定如何从原始张量中删除填充为0的此类列。

如何根据索引号从pytorch张量中删除零列?

谢谢

2 个答案:

答案 0 :(得分:1)

索引要保留的列而不是要删除的列更有意义。

valid_cols = []
for col_idx in range(A.size(1)):
    if not torch.all(A[:, col_idx] == 0):
        valid_cols.append(col_idx)
A = A[:, valid_cols]

或更神秘的

valid_cols = [col_idx for col_idx, col in enumerate(torch.split(A, 1, dim=1)) if not torch.all(col == 0)]
A = A[:, valid_cols]

答案 1 :(得分:1)

标识所有其条目都等于0的所有列

MyClientErrorMock

这将对每一列的绝对值求和,然后将结果转换为布尔值,即如果总和为零,则将其转换为布尔值non_empty_mask = A.abs().sum(dim=0).bool(),否则将其转换为False

仅删除所有条目都等于0的列

True

这只是将掩码应用于原始张量,即,它保留了A[:,non_empty_mask]non_empty_mask的行。