给定列*列表*,如何填充零列的TF张量

时间:2019-02-01 14:25:21

标签: python python-3.x tensorflow

在tensorflow中,我尝试在给定特定列列表的情况下填充零列的张量。

如何在tensorflow中实现它?我尝试使用tf.assigntf.scatter_nd,但遇到了一些错误。

这是一个简单的 numpy 实现

a_np = np.array([[1, 2],
                 [3, 4], 
                 [5, 6]])
columns = [1, 5]
a_padded = np.zeros((3, 7))
a_padded[:, columns] = a_np
print(a_padded)

## output ##

[[0. 1. 0. 0. 0. 2. 0.]
 [0. 3. 0. 0. 0. 4. 0.]
 [0. 5. 0. 0. 0. 6. 0.]]

我试图在张量流中做同样的事情:

a = tf.constant([[1, 2],
                 [3, 4], 
                 [5, 6]])
columns = [1, 5]
a_padded = tf.Variable(tf.zeros((3, 7)))
a_padded[:, columns].assign(a)

但这会产生以下错误:

  

TypeError:只能将列表(而不是“ int”)连接到列表

我也尝试使用tf.scatter_nd

a = tf.constant([[1, 2],
                 [3, 4], 
                 [5, 6]])
columns = [1, 5]
shape = tf.constant((3, 7))
tf.scatter_nd(columns, a, shape)

但这会产生以下错误:

  

InvalidArgumentError:输出形状的内部尺寸必须与更新形状的内部尺寸匹配。输出:[3,7]更新:[3,2] [Op:ScatterNd]

2 个答案:

答案 0 :(得分:1)

这是一个解决方案:

tf.reset_default_graph()
a = tf.constant([[1, 2], [3, 4], [5, 6]], dtype=tf.int32)
columns = tf.constant([1, 5], dtype=tf.int32)
a_padded = tf.Variable(tf.zeros((3, 7), dtype=tf.int32))
indices = tf.stack(tf.meshgrid(tf.range(tf.shape(a_padded)[0]), columns, indexing='ij'), axis=-1)
update_cols = tf.scatter_nd_update(a_padded, indices, a)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
print(sess.run(update_cols))

答案 1 :(得分:0)

(这里是OP),我设法使用tf.scatter_nd找到了解决方案。诀窍是对齐a的尺寸,列和输出形状。

a_np = np.array([[1, 2],
                 [3, 4], 
                 [5, 6]])

# Note the Transpose on every line below
a = tf.constant(a_np.T) 
columns = tf.constant(np.array([[1, 5]]).T.astype('int32'))
shape = tf.constant((7, 3))
a_padded = tf.transpose(tf.scatter_nd(columns, a, shape))