如果给定形状为a
的矩阵(5,3)
和形状为b
的索引数组(5,)
,我们可以轻松获取相应的向量c
,< / p>
c = a[np.arange(5), b]
但是,我不能用tensorflow做同样的事情,
a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, [5,])
# this line throws error
c = a[tf.range(5), b]
Traceback(最近一次调用最后一次):文件“”,第1行,in 文件 “〜/ anaconda2 / lib中/ python2.7 /站点包/ tensorflow /蟒蛇/ OPS / array_ops.py” 第513行,在_SliceHelper中 名称=名)
文件 “〜/ anaconda2 / lib中/ python2.7 /站点包/ tensorflow /蟒蛇/ OPS / array_ops.py” 第671行,strided_slice shrink_axis_mask = shrink_axis_mask)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / ops / gen_array_ops.py”, 第3688行,strided_slice shrink_axis_mask = shrink_axis_mask,name = name)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / op_def_library.py”, 第763行,在apply_op中 op_def = op_def)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / ops.py”, 第2397行,在create_op中 set_shapes_for_outputs(ret)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / ops.py”, 第1757行,在set_shapes_for_outputs中 shapes = shape_func(op)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / ops.py”, 第1707行,在call_with_requiring中 return call_cpp_shape_fn(op,require_shape_fn = True)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / common_shapes.py”, 第610行,在call_cpp_shape_fn中 debug_python_shape_fn,require_shape_fn)文件“〜/ anaconda2 / lib / python2.7 / site-packages / tensorflow / python / framework / common_shapes.py”, 第675行,在_call_cpp_shape_fn_impl中 raise ValueError(err.message)ValueError:Shape必须为1,但对于'strided_slice_14'(op:'StridedSlice'),输入为2 形状:[5,3],[2,5],[2,5],[2]。
我的问题是,如果我不能使用上述方法在numpy中产生tensorflow的预期结果,我该怎么办?
答案 0 :(得分:6)
此功能目前尚未在TensorFlow中实施。 GitHub issue #4638正在跟踪NumPy风格的“高级”索引的实现。但是,您可以使用tf.gather_nd()
运算符来实现您的程序:
a = tf.placeholder(tf.float32, shape=(5, 3))
b = tf.placeholder(tf.int32, (5,))
row_indices = tf.range(5)
# `indices` is a 5 x 2 matrix of coordinates into `a`.
indices = tf.transpose([row_indices, b])
c = tf.gather_nd(a, indices)