I'm using Keras 2.1.3 with the Theano backend. I have a Keras model named model that I'm trying to save in a format ML Engine accepts. After some Googling, I found this:
from keras.models import Model
import tensorflow as tf
from tensorflow.python.saved_model import builder, tag_constants, signature_constants
from tensorflow.python.saved_model.signature_def_utils_impl import predict_signature_def
tf.keras.backend.clear_session()
sess = tf.Session()
tf.keras.backend.set_session(sess)
# build a copy but without the learning nodes
tf.keras.backend.set_learning_phase(0) # disables creation of dropout
new = Model.from_config(model.get_config())
new.set_weights(model.get_weights())
# export saved model
model_builder = builder.SavedModelBuilder(export_path)
signature = predict_signature_def(
    inputs={'sequence': new.input},
    outputs={'outputs': new.output})
with tf.keras.backend.get_session() as sess:
    sdefm = {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: signature}
    model_builder.add_meta_graph_and_variables(sess=sess,
                                               tags=[tag_constants.SERVING],
                                               signature_def_map=sdefm)
    model_builder.save()
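One detail the snippet depends on: "from keras.models import Model" pulls in standalone Keras, whose backend is fixed when keras is first imported, while the tf.keras.backend calls only configure the Keras bundled with TensorFlow. Standalone Keras 2.x picks its backend from the KERAS_BACKEND environment variable, falling back to the "backend" key in ~/.keras/keras.json and then to "tensorflow". The sketch below mirrors that lookup order with the standard library only; pick_backend and configured_keras_backend are hypothetical helpers, not Keras functions, and the precedence is assumed to match Keras 2.x rather than guaranteed.

```python
import json
import os


def pick_backend(config, environ):
    """Apply the assumed Keras 2.x precedence to a parsed keras.json dict.

    KERAS_BACKEND in the environment wins; otherwise the "backend" key
    from keras.json; otherwise the default "tensorflow".
    """
    backend = config.get("backend", "tensorflow")
    return environ.get("KERAS_BACKEND", backend)


def configured_keras_backend(config_path=None, environ=None):
    """Return the backend standalone Keras would likely load (sketch).

    config_path defaults to ~/.keras/keras.json; environ defaults to os.environ.
    """
    if environ is None:
        environ = os.environ
    if config_path is None:
        config_path = os.path.join(os.path.expanduser("~"), ".keras", "keras.json")

    config = {}
    if os.path.exists(config_path):
        try:
            with open(config_path) as f:
                config = json.load(f)
        except ValueError:  # malformed JSON: fall back to the defaults
            config = {}
    if not isinstance(config, dict):
        config = {}
    return pick_backend(config, environ)


if __name__ == "__main__":
    # "theano" here means keras.models.Model builds Theano graphs, so
    # model.input is not a TensorFlow tensor.
    print(configured_keras_backend())
```

With Keras importable, keras.backend.backend() reports the active backend at runtime; a result of "theano" means new.input and new.output are Theano variables rather than TensorFlow tensors.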
But I get a strange traceback:
  File "/Users/...
    outputs={'outputs': new.output})
  File "/Users/.../tensorflow/python/saved_model/signature_def_utils_impl.py", line 177, in predict_signature_def
    for key, tensor in inputs.items()}
  File "/Users/.../tensorflow/python/saved_model/signature_def_utils_impl.py", line 177, in <dictcomp>
    for key, tensor in inputs.items()}
  File "/Users/.../tensorflow/python/saved_model/utils_impl.py", line 45, in build_tensor_info
    tensor_shape=tensor.get_shape().as_proto())
AttributeError: 'TensorVariable' object has no attribute 'get_shape'
Exception TypeError: TypeError("'NoneType' object is not callable",) in <bound method Session.__del__ of <tensorflow.python.client.session.Session object at 0x12ff1f150>> ignored
It looks like get_shape() is being called on my tensor, which is of type TensorVariable, and TensorVariable doesn't have that method.
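The failure can be reproduced without TensorFlow or Theano. Per the traceback, build_tensor_info evaluates tensor.get_shape().as_proto(), so any input object without a get_shape method raises the same AttributeError. The sketch below uses hypothetical stand-in classes (FakeTFTensor, FakeTheanoVariable) and a simplified build_info in place of the real utils_impl function, plus a hasattr-based check that flags offending inputs before an export is attempted.

```python
class FakeTFTensor:
    """Hypothetical stand-in for a TensorFlow tensor: it has get_shape()."""

    def get_shape(self):
        return (None, 10)


class FakeTheanoVariable:
    """Hypothetical stand-in for Theano's TensorVariable: no get_shape()."""

    ndim = 2  # Theano variables expose ndim, but not get_shape()


def build_info(tensor):
    """Simplified analogue of build_tensor_info: it needs get_shape()."""
    return {"shape": tensor.get_shape()}


def find_non_tf_inputs(inputs):
    """Return sorted keys whose values lack get_shape()."""
    return sorted(key for key, t in inputs.items() if not hasattr(t, "get_shape"))


if __name__ == "__main__":
    inputs = {"sequence": FakeTheanoVariable()}
    print(find_non_tf_inputs(inputs))  # ['sequence']
    try:
        build_info(inputs["sequence"])
    except AttributeError as err:
        # Same failure mode as the traceback above.
        print(err)
```

Running such a check on {'sequence': new.input} before calling predict_signature_def would surface the mismatch with a clearer message than the dict-comprehension traceback.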
What am I doing wrong? This feels like a very common use case, so it shouldn't be this hard, should it?
Thanks!