我正在尝试根据索引将一个Tensor添加到另一个Tensor。我认为scatter_add正是为了做到这一点。这是我为此编写的测试代码。
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
import tensorflow as tf
from tensorflow.python.ops import array_ops
with tf.Graph().as_default():
sess = tf.Session()
a = tf.get_variable("test",[3,3,2], dtype=tf.int32, initializer=tf.constant_initializer(0))
b = tf.constant([[1,1],[3,3]],dtype=tf.int32)
c = tf.constant([[0,1],[2,2]],dtype=tf.int32)
d = tf.scatter_add(a,b,c,use_locking=None)
init = tf.initialize_all_variables()
sess.run(init)
print(sess.run([d]))
最终输出应为:
[[[0,0], [1,1], [0,0]],
[[0,0], [0,0], [0,0]],
[[0,0], [0,0], [3,3]]]
但是我收到了这个错误:
Traceback (most recent call last):
File "test.py", line 16, in <module>
d = tf.scatter_add(a,b,c,use_locking=None)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 227, in scatter_add
name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 703, in apply_op
op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2312, in create_op
set_shapes_for_outputs(ret)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1704, in set_shapes_for_outputs
shapes = shape_func(op)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/state_ops.py", line 235, in _ScatterUpdateShape
indices_shape.concatenate(var_shape[1:]))
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/tensor_shape.py", line 570, in merge_with
(self, other))
ValueError: Shapes (2, 2) and (2, 2, 3, 2) are not compatible