用nfy中的tf.slice检测越界切片

时间:2017-07-11 09:30:53

标签: python numpy tensorflow slice

在tensorflow中,我试图使用tf.slice,但是as its documentation states,它需要切片适合输入数组。例如,如果你试图切割张量[1,2,3,4]的前5个位置,它将崩溃。我想拥有与python列表或numpy数组相同的功能,其中切片可以获得原始数组和您要求的切片的交集。例如,如果你要求[1,2,3,4]中的2到6个位置,你将获得[2,3,4]。

我怎样才能在张量流中做到这一点?

谢谢!

2 个答案:

答案 0 :(得分:3)

你可以使用tensorflow的python切片运算符,它比tf.slice略强一些,特别是一些绑定检查的行为与numpy类似。

x = tf.zeros((10,))
y = tf.slice(x, [5], [15])
print(y.shape)
# (15,)
z = x[5:42]
print(z.shape)
# (5,)

答案 1 :(得分:0)

您可以使用tf.clip_by_value

def robust_slice( t, begin, size, name=None ):
    with tf.name_scope('robust_slice'):
        new_size = tf.clip_by_value(size + begin, clip_value_min = 0, clip_value_max = t.shape()) - begin
        return tf.slice(t, begin, new_size, name)

(没有经过测试,但这样的事情应该可以完成这项工作)