我使用python 3.6和tensorflow 1.14。
我得到一个动态创建的张量流模型,需要将其与其他信息一起存储。因此,我想将整个训练好的模型序列化为一个字符串。然后将该字符串与其他信息一起存储为文件(例如json或pickle)。
换句话说,我想在该示例代码中编写serialize
和deserialize
函数:
import tensorflow as tf
from tensorflow import keras
def get_trained_model() -> tf.keras.Sequential:
"""
Dummy code to build a model.
"""
model = keras.Sequential([
keras.layers.Dense(8, activation=tf.nn.relu),
keras.layers.Dense(8, activation=tf.nn.relu),
keras.layers.Dense(2, activation=tf.nn.softmax)
])
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
def serialize(model:tf.keras.Sequential) -> str:
"""
Serializes the whole tensorflow model (architecture and weights) as string, so that it can be
saved, shared and later deserialized.
"""
pass
def deserialize(model_string: str) -> tf.keras.Sequential:
"""
Deserializes a string into a working tensorflow model object.
"""
pass
model = get_trained_model()
m = serialize(model)
new_model = deserialize(m)
我当前的解决方案是使用model.save()
,将模型写入文件,从那里读取字符串,然后再次删除文件。由于此解决方案不是很干净,因此我寻求更好的解决方案。
非常感谢您的帮助或建议。