从Keras的3-d张量中收集2-d张量列表

时间:2017-11-17 09:53:10

标签: machine-learning tensorflow nlp keras

我有一个名为 main_decoder 的三维张量(无,9,256)

我想提取9个形状(无,256)

的张量

我尝试过使用Keras 收集,以下是模式代码段:

for i in range(0,9):
    sub_decoder_input = Lambda(lambda main_decoder:gather(main_decoder,(i)), name='lambda'+str(i))(main_decoder)

结果是9个λ层的形状(9,256)

如何修改它以便我可以获得或收集形状(无,256)的9个张量

感谢。

1 个答案:

答案 0 :(得分:3)

您可以将3D张量切片为9个2D张量,并从Lambda图层返回张量列表。

main_decoder = Input(shape=(9, 256))
sub_decoder_input = Lambda(lambda x: [x[:, i, :] for i in range(9)])(main_decoder)

print(sub_decoder_input)
[<tf.Tensor 'lambda_1/strided_slice:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_1:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_2:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_3:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_4:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_5:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_6:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_7:0' shape=(?, 256) dtype=float32>,
 <tf.Tensor 'lambda_1/strided_slice_8:0' shape=(?, 256) dtype=float32>]