我有一个名为 main_decoder 的三维张量(无,9,256)
我想提取9个形状(无,256)
的张量我尝试过使用Keras 收集,以下是模式代码段:
for i in range(0,9):
sub_decoder_input = Lambda(lambda main_decoder:gather(main_decoder,(i)), name='lambda'+str(i))(main_decoder)
结果是9个λ层的形状(9,256)
如何修改它以便我可以获得或收集形状(无,256)的9个张量
感谢。
答案 0 :(得分:3)
您可以将3D张量切片为9个2D张量,并从Lambda
图层返回张量列表。
main_decoder = Input(shape=(9, 256))
sub_decoder_input = Lambda(lambda x: [x[:, i, :] for i in range(9)])(main_decoder)
print(sub_decoder_input)
[<tf.Tensor 'lambda_1/strided_slice:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_1:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_2:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_3:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_4:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_5:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_6:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_7:0' shape=(?, 256) dtype=float32>,
<tf.Tensor 'lambda_1/strided_slice_8:0' shape=(?, 256) dtype=float32>]