我正在使用Tensorflow上的Keras功能API创建一个广泛而深入的模型。
当我尝试合并两个模型时,发生以下错误。
-------------------------------------------------- ---------------------------- ValueError Traceback(最近的呼叫 最后)在() 1 merged_out = tf.keras.layers.concatenate([wide_model.output,deep_model.output]) 2 merged_out = tf.keras.layers.Dense(1)(merged_out) ----> 3 Combined_model = tf.keras.Model(inputs = wide_model.input + [deep_model.input],outputs = merged_out) 4打印(combined_model.summary())
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/training.py 在 init (自身,* args,** kwargs)中 111 112 def init (自身,* args,** kwargs): -> 113超级(模型,自我)。初始化(* args,** kwargs) 114#为迭代器get_next op创建一个缓存。 115 self._iterator_get_next = weakref.WeakKeyDictionary()
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py 在 init (自身,* args,** kwargs)中 77(以kwargs为单位的“输入”和以kwargs为单位的“输出”): 78#图形网络 ---> 79 self._init_graph_network(* args,** kwargs) 其他80个 81#子类化网络
/usr/local/lib/python3.6/dist-packages/tensorflow/python/training/checkpointable/base.py 在_method_wrapper中(self,* args,** kwargs) 362 self._setattr_tracking = False#pylint:disable =受保护的访问 363尝试: -> 364方法(self,* args,** kwargs) 365最后: 366 self._setattr_tracking = previous_value#pylint:disable =受保护的访问
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/network.py 在_init_graph_network中(自己,输入,输出,名称) 193'必须来自
tf.layers.Input
。 ' 194'已收到:'+ str(x)+ -> 195'(缺少上一层元数据)。) 196#检查x是输入张量。 197#pylint:disable = protected-accessValueError:模型的输入张量必须来自
tf.layers.Input
。 收到:Tensor(“ add_1:0”,shape =(1,?,163),dtype = float32)(丢失 前一层的元数据)。
这是将两者串联的代码。
merged_out = tf.keras.layers.concatenate([wide_model.output, deep_model.output])
merged_out = tf.keras.layers.Dense(1)(merged_out)
combined_model = tf.keras.Model(inputs=wide_model.input + [deep_model.input], outputs=merged_out)
print(combined_model.summary())
对于每个模型的输入,我尝试将tf.layers.Input
与
inputs = tf.placeholder(tf.float32, shape=(None,X_resampled.shape[1]))
deep_inputs = tf.keras.Input(tensor=(inputs))
使它们成为this page所述的tf.layers.Input
。
但是我仍然面临着同样的问题。
我正在使用tensorflow == 1.10.0
有人可以帮我解决这个问题吗?
谢谢!
答案 0 :(得分:1)
在inputs=wide_model.input + [deep_model.input]
中,wide.model.input
可能不是列表,因此您要传递新的Add
张量而不是输入列表。尝试通过inputs=[wide_model.input] + [deep_model.input]
代替