将tensorflow.keras模型保存到字符串

时间:2019-10-07 14:19:38

标签: python python-3.x tensorflow

我使用python 3.6和tensorflow 1.14。

我得到一个动态创建的张量流模型,需要将其与其他信息一起存储。因此,我想将整个训练好的模型序列化为一个字符串。然后将该字符串与其他信息一起存储为文件(例如json或pickle)。

换句话说,我想在该示例代码中编写serializedeserialize函数:

import tensorflow as tf
from tensorflow import keras

def get_trained_model() -> tf.keras.Sequential:
    """
    Dummy code to build a model.
    """
    model = keras.Sequential([
        keras.layers.Dense(8, activation=tf.nn.relu),
        keras.layers.Dense(8, activation=tf.nn.relu),
        keras.layers.Dense(2, activation=tf.nn.softmax)
    ])
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    return model

def serialize(model:tf.keras.Sequential) -> str:
    """
    Serializes the whole tensorflow model (architecture and weights) as string, so that it can be
    saved, shared and later deserialized.
    """
    pass

def deserialize(model_string: str) -> tf.keras.Sequential:
    """
    Deserializes a string into a working tensorflow model object.
    """
    pass

model = get_trained_model()
m = serialize(model)
new_model = deserialize(m)

我当前的解决方案是使用model.save(),将模型写入文件,从那里读取字符串,然后再次删除文件。由于此解决方案不是很干净,因此我寻求更好的解决方案。

非常感谢您的帮助或建议。

0 个答案:

没有答案