如何将张量中的列设置为0s?

时间:2019-04-16 16:20:22

标签: python tensorflow

假设我的张量看起来像:

[[4, 5, 6],
 [7, 8, 9],
 [4, 125, 6],
 [72, 81, 91]]

并且我希望第一列和最后一列中的所有值均为0,所以结果看起来像这样:

[[0, 5, 0],
 [0, 8, 0],
 [0, 125, 0],
 [0, 81, 0]]

在tensorflow(1.4版)中最简单,最有效的方法是什么?

我尝试使用tf.boolean_mask,但是我的tf版本不支持axis参数,并且似乎删除了列而不是将它们归零。

任何帮助将不胜感激,谢谢!

编辑: 输入张量可以具有大于3的任意数量的列,而输出张量将是具有恰好“清零”的两列(指定)的张量。

1 个答案:

答案 0 :(得分:1)

使用tf.one_hot()在第二列中创建一个带有张量的掩码张量,然后与您想要第二列中的张量进行元素逐个相乘:

import tensorflow as tf

tensor = tf.constant([[4, 5, 6],
                      [7, 8, 9],
                      [4, 125, 6],
                      [72, 81, 91]], dtype=tf.float32)

col_to_zero = [0, 2] # <-- column numbers you want to be zeroed out
tnsr_shape = tf.shape(tensor)
mask = [tf.one_hot(col_num*tf.ones((tnsr_shape[0], ), dtype=tf.int32), tnsr_shape[-1])
        for col_num in col_to_zero]
mask = tf.reduce_sum(mask, axis=0)
mask = tf.cast(tf.logical_not(tf.cast(mask, tf.bool)), tf.float32)

result = tensor * mask

with tf.Session() as sess:
    print(result.eval())
# [[  0.   5.   0.]
#  [  0.   8.   0.]
#  [  0. 125.   0.]
#  [  0.  81.   0.]]

col_to_zero是要清零的列号的列表。例如,设置col_to_zero = [1, 2]将仅保留示例的第一列,并将所有其他列归零。