Tensorflow:如何在张数不变的情况下切片张量?

时间:2018-08-03 09:57:12

标签: python tensorflow

例如,如果我们有:

a = tf.constant(np.eye(5))
a
<tf.Tensor 'Const:0' shape=(5, 5) dtype=float64>
a[0,:]
<tf.Tensor 'strided_slice:0' shape=(5,) dtype=float64>

张量a的切片会将尺寸2的原始数量减少为1

我怎么能直接获得排名不变的切片??

a[0,:]
<tf.Tensor 'strided_slice:0' shape=(1,5) dtype=float64>

tf.expand_dims(a[0,:], axis=0)可以工作,但是还有更直接,更简单的方法吗?)

1 个答案:

答案 0 :(得分:1)

至少有两种直接方法,与NumPy(related question)中可用的方法非常相似。

  1. 在尺寸为1的该轴上获取一个范围:a[x:x+1]
  2. 使用None添加一个轴:a[None, x]
a[0:1]
<tf.Tensor 'strided_slice_1:0' shape=(1, 5) dtype=float64>

一些实际的张量运行显示了预期的结果。

with tf.Session() as sess:
    sess.run(a[0])
    sess.run(a[0:1])
    sess.run(a[None, 0])
array([1., 0., 0., 0., 0.])
array([[1., 0., 0., 0., 0.]])
array([[1., 0., 0., 0., 0.]])