TensorFlow strided_slice不会在整个范围内向后迭代

时间:2019-04-30 23:05:02

标签: python tensorflow

我有以下测试。我想反转行,最后一行是第一行,依此类推。

    x_val = np.arange(3 * 2 * 3).astype(np.int64).reshape((3, 2, 3))
    print(x_val)
    x = tf.placeholder(tf.float32, x_val.shape, name='input')
    x_ = tf.strided_slice(x, [3], [0], [-1])
    _ = tf.identity(x_, name='output')
    with tf.Session() as sess:
        variables_lib.global_variables_initializer().run()
        output_dict = []
        for out_name in ['output:0']:
            output_dict.append(sess.graph.get_tensor_by_name(out_name))
        expected = sess.run(output_dict, feed_dict={"input:0": x_val})
        print('test_strided_slice1', expected[0].shape, expected[0])

我希望我的输出是:

[
  [
    [12. 13. 14.]
    [15. 16. 17.]
  ]
  [
    [ 6.  7.  8.]
    [ 9. 10. 11.]
  ]
  [
    [ 0  1  2]
    [ 3  4  5]
  ]
]

但是我得到了

[
  [
    [12. 13. 14.]
    [15. 16. 17.]
  ]
  [
    [ 6.  7.  8.]
    [ 9. 10. 11.]
  ]
]

您会看到错过了现在应该是第一行的第一行。

如果我像0:3:1那样逐步通过,则会得到所有行。但是,如果我倒车,我少拿一个。

不提供“ end”索引将导致测试失败。将'end'设置为-1也会导致空输出。

关于如何完成此操作的任何建议?

1 个答案:

答案 0 :(得分:1)

在大多数情况下,通过Python索引语法使用tf.strided_slice更方便,所以您可以这样做:

x_ = x[::-1]

但是,可以直接使用tf.strided_slice做您想做的事情。为此,您需要使用end_mask参数。在此整数值中,如果设置了第i个比特(从最低有效位开始),则将忽略第i个维度的相应end值,并尽可能地采用分片。因此,您可以这样做:

x_ = tf.strided_slice(x, [3], [0], [-1], end_mask=1)

请注意,我已经将4中的begin更改为3,因为这是切片开始的实际索引(尽管它也适用于4) 。如果您只是想从头到尾进行切片,也可以使用start_mask,其作用类似于end_mask

x_ = tf.strided_slice(x, [0], [0], [-1], start_mask=1, end_mask=1)

一个小例子:

import tensorflow as tf

with tf.Graph().as_default():
    x = tf.reshape(tf.range(18), [3, 2, 3])
    x_ = tf.strided_slice(x, [0], [0], [-1], start_mask=1, end_mask=1)
    with tf.Session() as sess:
        print(sess.run(x_))

输出:

[[[12 13 14]
  [15 16 17]]

 [[ 6  7  8]
  [ 9 10 11]]

 [[ 0  1  2]
  [ 3  4  5]]]