例如,如果我们有:
a = tf.constant(np.eye(5))
a
<tf.Tensor 'Const:0' shape=(5, 5) dtype=float64>
a[0,:]
<tf.Tensor 'strided_slice:0' shape=(5,) dtype=float64>
张量a
的切片会将尺寸2
的原始数量减少为1
我怎么能直接获得排名不变的切片??
a[0,:]
<tf.Tensor 'strided_slice:0' shape=(1,5) dtype=float64>
(tf.expand_dims(a[0,:], axis=0)
可以工作,但是还有更直接,更简单的方法吗?)
答案 0 :(得分:1)
至少有两种直接方法,与NumPy(related question)中可用的方法非常相似。
a[x:x+1]
None
添加一个轴:a[None, x]
a[0:1]
<tf.Tensor 'strided_slice_1:0' shape=(1, 5) dtype=float64>
一些实际的张量运行显示了预期的结果。
with tf.Session() as sess:
sess.run(a[0])
sess.run(a[0:1])
sess.run(a[None, 0])
array([1., 0., 0., 0., 0.])
array([[1., 0., 0., 0., 0.]])
array([[1., 0., 0., 0., 0.]])