给定长度= batch_size的索引张量列表,如何索引形状为(batch_size,200,256)的张量以获得(batch_size,1,256)?

时间:2019-10-12 14:22:27

标签: python python-3.x tensorflow keras

我有LSTM层的输出,其形状为(batch_size,200,256),其中200是令牌序列的长度,而256是LSTM输出尺寸。 我还具有另一个形状为(batch_size)的张量,它是要从批次中的每个样本序列中切出的标记的索引列表。

如果令牌索引不是-1,我将切出令牌向量表示形式(长度= 256)。如果令牌索引为-1,我将给出零个向量(长度= 256)。

预期的输出结果的形状为(batch_size,1,256)。我该怎么办?

谢谢

这是我到目前为止尝试过的

bidir = concatenate([forward, backward]) # shape = (batch_size, 200, 256) 
dropout = Dropout(params['dropout_rate'])(bidir)
def slice_by_tensor(x):
    matrix_to_slice = x[0]
    index_tensor = x[1]


    out_tensor = tf.where(index_tensor == -1, 
                          tf.zeros(tf.shape(tf.gather(matrix_to_slice, 
                                                      index_tensor, axis=1))), 
                          tf.gather(matrix_to_slice, index_tensor, axis=1))



    return out_tensor


representation_stack0 = Lambda(lambda x: slice_by_tensor(x))([dropout,stack_idx0]) 
# stack_idx0 shape is (batch_size) 
# I got output with shape (batch_size, batch_size, 256) with this code

1 个答案:

答案 0 :(得分:0)

a=tf.reshape(tf.range(2*3*4),shape=(2,3,4))
#     [[[ 0,  1,  2,  3],
#        [ 4,  5,  6,  7],
#        [ 8,  9, 10, 11]],

#      [[12, 13, 14, 15],
#      [16, 17, 18, 19],
#       [20, 21, 22, 23]]]

b=tf.constant([-1,2]) 

aa=tf.pad(a,[[0,0],[1,0],[0,0]]) 

bb=b+1 

index=tf.stack([tf.range(tf.size(b)),bb],axis=-1) 
res=tf.expand_dims(tf.gather_nd(aa, index),axis=1)
#[[[ 0,  0,  0,  0]],
#[[20, 21, 22, 23]]]

当index为-1时,我们需要像张量一样的零。因此,我们可以先沿第二个轴填充原始张量。然后将索引增加1。此后,使用tf.gather_nd将返回答案。