张量中的Tensorflow访问元素使用索引上的tenors

时间:2018-01-25 11:42:26

标签: python numpy tensorflow indexing

如何使用张量索引访问张量流Tensor中的次要元素,如下所示:

import tensorflow as tf
import numpy as np

# indexing in numpy [Working]
matrix = np.random.randint(0, 10, [100, 100])
indices = np.random.randint(0, 100, [1000, 100])
elements = matrix[indices[:, 0], indices[:, 1]]

# indexing in tensorflow [Not working]
tf_matrix = tf.constant(matrix, dtype=tf.int32)
tf_indices = tf.constant(indices, dtype=tf.int32)
tf_elements = tf_matrix[tf_indices[:, 0], tf_indices[:, 1]]  # Error

session = tf.Session()
session.run(tf_elements)

我收到这些错误:

  

tensorflow.python.framework.errors_impl.InvalidArgumentError:Shape   必须是等级1但是'strided_slice_2'的等级2(op:   'StridedSlice')输入形状:[100,100],[2,1000],[2,1000],[2]。

     

ValueError:Shape必须为1级,但'strided_slice_2'的排名为2   (op:'StridedSlice')输入形状:[100,100],[2,1000],[2,1000],   [2]。

1 个答案:

答案 0 :(得分:0)

tf_elements = tf.gather_nd(tf_matrix, tf_indices[:, 0:2])