在其他维度上重复keras(tensorflow)模型

时间:2019-03-14 12:58:53

标签: python tensorflow keras

假设我有一个模型,该模型将形状为[n,10]的张量映射到形状为[n,2]的张量,其中n是批量大小。如何重复模型,以使生成的模型接受形状为[n,k,10]的输入张量并输出形状为[n,k,2]的张量?模型的k个版本应共享所有权重。

1 个答案:

答案 0 :(得分:1)

您可以执行以下操作:

input_ = Input((k, model.input.shape[1]))
input_as_list = Lambda(lambda x: tf.unstack(x, axis=1))(input_)
model_outputs = [model(x) for x in input_as_list] 
model_outputs = [Lambda(lambda x: K.expand_dims(x, axis=1))(y) for y in model_outputs]
concat_output = Concatenate(axis=1)(model_outputs)
new_model = Model(input_, concat_output)