请考虑以下代码:
x = tf.Variable([1.0,2.0,3.0])
i = tf.Variable([1], dtype = tf.int64)
x[i]
显然,tensorflow抛出错误,因为x的形状类型与i的类型不同。我可以通过将i转换为int32来解决它,但还有其他方法吗?例如,我可以更改x的形状类型吗?
答案 0 :(得分:1)
据我所知,tensorflow不支持像{numpy那样通过__getitem__
进行切片。另一种方法是使用tf.gather
:
x = tf.Variable([1.0,2.0,3.0])
i = tf.Variable([1], dtype = tf.int64)
tf.gather(x, i)