How do I perform scatter_add on a multidimensional tensor in TensorFlow?

Time: 2016-08-28 20:22:42

Tags: tensorflow

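I am trying to add one Tensor to another Tensor at positions given by a set of indices. I thought scatter_add was meant for exactly this. Here is the test code I wrote for it.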

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow.python.ops import array_ops


with tf.Graph().as_default():
    sess = tf.Session()

    a = tf.get_variable("test",[3,3,2], dtype=tf.int32, initializer=tf.constant_initializer(0))
    b = tf.constant([[1,1],[3,3]],dtype=tf.int32)
    c = tf.constant([[0,1],[2,2]],dtype=tf.int32)
    d = tf.scatter_add(a,b,c,use_locking=None)

    init = tf.initialize_all_variables()
    sess.run(init)
    print(sess.run([d]))

The final output should be:

[[[0,0], [1,1], [0,0]],
 [[0,0], [0,0], [0,0]],
 [[0,0], [0,0], [3,3]]]
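
The result above corresponds to treating each row of c as a full (row, column) index into the first two dimensions of a, and adding the matching row of b at that position. Below is a minimal pure-Python sketch of that intended behavior, with no TensorFlow involved. The helper name scatter_nd_add_sketch is hypothetical, and the roles of b (updates) and c (indices) are inferred from the expected output rather than stated in the question.

```python
def scatter_nd_add_sketch(target, indices, updates):
    """Add updates[k] to the slice of target addressed by indices[k].

    Hypothetical helper for illustration only: target is a nested list of
    shape [3, 3, 2], each index is a (row, col) pair, and each update is a
    list with the same length as the innermost dimension.
    """
    for (row, col), update in zip(indices, updates):
        for k, value in enumerate(update):
            target[row][col][k] += value
    return target


# Zero-initialized 3x3x2 "variable", mirroring `a` in the question.
a = [[[0, 0] for _ in range(3)] for _ in range(3)]
indices = [[0, 1], [2, 2]]   # the values of `c` in the question
updates = [[1, 1], [3, 3]]   # the values of `b` in the question

result = scatter_nd_add_sketch(a, indices, updates)
for plane in result:
    print(plane)
```

Later TensorFlow releases include tf.scatter_nd_add, which accepts full N-dimensional indices of this kind. Whether it is available depends on the installed version.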

But I get this error instead:

Traceback (most recent call last):
  File "test.py", line 16, in <module>
    d = tf.scatter_add(a,b,c,use_locking=None)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_state_ops.py", line 227, in scatter_add
    name=name)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 703, in apply_op
    op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2312, in create_op
    set_shapes_for_outputs(ret)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1704, in set_shapes_for_outputs
    shapes = shape_func(op)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/state_ops.py", line 235, in _ScatterUpdateShape
    indices_shape.concatenate(var_shape[1:]))
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/tensor_shape.py", line 570, in merge_with
    (self, other))
ValueError: Shapes (2, 2) and (2, 2, 3, 2) are not compatible
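
The last frames of the traceback hint at the shape rule being checked. _ScatterUpdateShape merges the shape of the updates with indices_shape.concatenate(var_shape[1:]), which suggests that scatter_add indexes only the first dimension of the variable. Below is a small pure-Python sketch of that check. It assumes fully known static shapes and uses plain equality in place of TensorFlow's shape-compatibility test. expected_updates_shape is a hypothetical helper, not a TensorFlow function.

```python
def expected_updates_shape(var_shape, indices_shape):
    """Shape the updates must have: indices_shape + var_shape[1:]."""
    return tuple(indices_shape) + tuple(var_shape[1:])


var_shape = (3, 3, 2)      # shape of `a`
indices_shape = (2, 2)     # shape of `b`, passed as the indices argument
updates_shape = (2, 2)     # shape of `c`, passed as the updates argument

expected = expected_updates_shape(var_shape, indices_shape)
compatible = expected == updates_shape
print(expected, updates_shape, compatible)
```

Under this rule, a one-dimensional index tensor of shape (2,) would require updates of shape (2, 3, 2): one whole 3x2 slice of a for each index.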

0 answers:

No answers