使用int64标量切片int32形状的张量

时间:2017-09-07 09:11:27

标签: python indexing tensorflow

请考虑以下代码:

x = tf.Variable([1.0,2.0,3.0])
i = tf.Variable([1], dtype = tf.int64)
x[i]
显然,

tensorflow抛出错误,因为x的形状类型与i的类型不同。我可以通过将i转换为int32来解决它,但还有其他方法吗?例如,我可以更改x的形状类型吗?

1 个答案:

答案 0 :(得分:1)

据我所知,tensorflow不支持像{numpy那样通过__getitem__进行切片。另一种方法是使用tf.gather

x = tf.Variable([1.0,2.0,3.0])
i = tf.Variable([1], dtype = tf.int64) 
tf.gather(x, i)