使用tensorflow中的数组更新张量

时间:2017-12-05 14:24:27

标签: python tensorflow

我有一个3-D张量,我想在这个numpy代码中更新一些值:

x_np = np.zeros((2,3,2))
updates_np = np.ones((2,2))
indices_1 = np.array([0, 1])
indices_2 = np.array([1, 1])

x_np[indices_1, indices_2, :] = updates_np

这就是结果:

>>> x_np

array([[[ 0.,  0.],
        [ 1.,  1.],
        [ 0.,  0.]],

       [[ 0.,  0.],
        [ 1.,  1.],
        [ 0.,  0.]]])

因为我可以在最后一个维度上更新完整行,所以为此我尝试使用scatter_update(我也尝试使用dynamic_stitch)。但似乎scatter_update不接受多个维度作为更新索引。

我怎么能在张量流中做到这一点?

1 个答案:

答案 0 :(得分:0)

我找到了一个解决方案,对我来说看起来很糟糕,所以欢迎任何改进

sess = tf.InteractiveSession()
x = tf.zeros((2,3,2))
indices_1 = tf.constant([0, 1])
indices_2 = tf.constant([1, 1])
updates_tf = tf.ones((2, 2))  

x_temp = tf.Variable(tf.reshape(x, [-1, 2]))

indices_flat = tf.shape(x)[1]*indices_1 + indices_2
sess.run(tf.initialize_all_variables())
print(sess.run(tf.reshape(tf.scatter_update(x_temp, indices_flat,  updates_tf), tf.shape(x))))

结果

[[[ 0.  0.]
  [ 1.  1.]
  [ 0.  0.]]

 [[ 0.  0.]
  [ 1.  1.]
  [ 0.  0.]]]