我有一个3-D张量,我想在这个numpy代码中更新一些值:
x_np = np.zeros((2,3,2))
updates_np = np.ones((2,2))
indices_1 = np.array([0, 1])
indices_2 = np.array([1, 1])
x_np[indices_1, indices_2, :] = updates_np
这就是结果:
>>> x_np
array([[[ 0., 0.],
[ 1., 1.],
[ 0., 0.]],
[[ 0., 0.],
[ 1., 1.],
[ 0., 0.]]])
因为我可以在最后一个维度上更新完整行,所以为此我尝试使用scatter_update(我也尝试使用dynamic_stitch)。但似乎scatter_update不接受多个维度作为更新索引。
我怎么能在张量流中做到这一点?
答案 0 :(得分:0)
我找到了一个解决方案,对我来说看起来很糟糕,所以欢迎任何改进
sess = tf.InteractiveSession()
x = tf.zeros((2,3,2))
indices_1 = tf.constant([0, 1])
indices_2 = tf.constant([1, 1])
updates_tf = tf.ones((2, 2))
x_temp = tf.Variable(tf.reshape(x, [-1, 2]))
indices_flat = tf.shape(x)[1]*indices_1 + indices_2
sess.run(tf.initialize_all_variables())
print(sess.run(tf.reshape(tf.scatter_update(x_temp, indices_flat, updates_tf), tf.shape(x))))
结果
[[[ 0. 0.]
[ 1. 1.]
[ 0. 0.]]
[[ 0. 0.]
[ 1. 1.]
[ 0. 0.]]]