如何在张量中移动值

时间:2018-01-11 20:11:06

标签: tensorflow

我有张量T形状[batch_size,A],其值和张量S的形状[batch_size]带有移位参数。

我想将T [b]中的值移位到右边的S [b]位置,应该删除T [b]的最后一个S [b]元素,并将新元素设置为0。

所以基本上想做点什么:

for i in range(batch_size):
  T[i] = zeros[:S[i]] + T[i, :A-S[i]]

示例:

For:
T = [[1, 2, 3], [4, 5, 6]]
S = [1, 2]

Return:
T' = [[0, 1, 2], [0, 0, 4]]

有没有简单的方法呢?

2 个答案:

答案 0 :(得分:1)

如果您在Tensorflow 2中工作,则可以为此目的使用tf.roll

“元素通过以下元素正向(朝着较大的索引)移动 沿轴尺寸的偏移量。负偏移值 会沿相反方向移动元素。滚动元素 通过最后一个位置将绕到第一个,反之亦然。 可以指定沿多个轴的多个移位。”

tf.roll(
       input, shift, axis, name=None
)

# 't' is [0, 1, 2, 3, 4]
roll(t, shift=2, axis=0) ==> [3, 4, 0, 1, 2]

# shifting along multiple dimensions
# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
roll(t, shift=[1, -2], axis=[0, 1]) ==> [[7, 8, 9, 5, 6], [2, 3, 4, 0, 1]]

# shifting along the same axis multiple times
# 't' is [[0, 1, 2, 3, 4], [5, 6, 7, 8, 9]]
roll(t, shift=[2, -3], axis=[1, 1]) ==> [[1, 2, 3, 4, 0], [6, 7, 8, 9, 5]]

答案 1 :(得分:0)

您可以为此目的使用tf.concat和tf.stack:

T_shift = tf.zeros((batch_size, A), tf.float32)
tmp = []

for i in xrange(batch_size):
    tmp.append(tf.concat([T_shift[i, :S[i, 0]],T[i, :17 - S[i,0]]], axis = 0))
T_shift = tf.stack(tmp)