TensorFlow:使用索引从张量中获取子张量

时间:2016-12-23 01:35:12

标签: tensorflow

Atensorflow.tensor,其形状为(2261,)

我希望从A的{​​{1}}:[10,20,30]

中获得一个新的张量

我尝试了以下所有内容,但都没有效果:

A[[10,20,30]]
# *** ValueError: Index out of range using input dim 1; input has only 1 dims for 'strided_slice' (op: 'StridedSlice') with input shapes: [2261], [3], [3], [3].

A[10,20,30]
# same error as above 

A[numpy.array([10,20,30])]
# *** ValueError: Shape must be rank 1 but is rank 2 for 'strided_slice' (op: 'StridedSlice') with input shapes: [2261], [1,3], [1,3], [1].

A[10]
# <tf.Tensor 'strided_slice:0' shape=() dtype=float32> - not an error but a shapeless tensor

A[tensorflow.constant(10)]
# same problem as above

为什么这些不起作用,我该怎么办?

3 个答案:

答案 0 :(得分:4)

C = tf.nn.embedding_lookup(A, B)

其中B是张量,值为[10,20,30]

供参考:https://www.tensorflow.org/api_docs/python/nn/embeddings

答案 1 :(得分:3)

我认为你要找的是聚集功能。

B = tf.constant([10, 20, 30])
tf.gather(A, B)

https://www.tensorflow.org/api_docs/python/tf/gather

答案 2 :(得分:0)

我不认为在TensorFlow中支持这样的花式索引。密切注意https://github.com/tensorflow/tensorflow/issues/206更新(也许还有其他地方)。

如果您想查看 的可用内容,看起来他们有一些关于__tensor.__getitem__的文档。