如何在tensorflow简单保存中指定变量

时间:2019-06-19 17:02:41

标签: tensorflow keras

我尝试使用简单的save方法保存我的tensorflow模型失败。

我已经使用keras构建了一个模型,并成功地对其进行了训练,准确度达到了88%。我现在正在尝试保存该模型以便我们可以使用它,但是我需要的功能(简单保存)尚不清楚如何指定要传入的变量。

会话和导出目录足够清晰,但是输入和输出是神秘的。我相信,因为我使用过Keras,所以这些变量被keras的抽象所隐藏,而Tensorflow中有关简单保存的文档没有提供任何解释。

作为一个冰雹,我只是将Z设置为y,但这显然是错误的。我是否需要设置输出变量Z,如果是,它是什么类型?

不确定是否足够的代码来深入探讨此问题。甚至指向正确的文档也将大有助益。

import tensorflow as tf
session =  tf.keras.backend.get_session()
export_dir = "/Users/somedir/"
z = np.array([])
tf.saved_model.simple_save(session,
            export_dir,
            inputs={"x": X, "y": y},
            outputs={"z": z})

X是我的数据集-所有自变量的数组。 Y是结果(因变量)。我没有z的其他候选者,因此将其设置为空数组。

我得到AttributeError:'numpy.ndarray'对象没有属性'get_shape'

1 个答案:

答案 0 :(得分:0)

结果证明,您可以查询模型本身以获取其输入和输出。

别忘了导入正确的库:

import time
import tensorflow as tf
import tensorflow.python.saved_model

然后设置一个导出路径变量,为方便起见,它带有时间戳,因此您可以一次又一次地运行它:

export_path = "/somedirectory/{}".format(time.strftime("%Y%m%d_%H%M%S"))

然后在get_session()块中,将执行以下操作:

with tf.keras.backend.get_session() as sess:
    tf.saved_model.simple_save(
        sess,
        export_path,
        inputs={t.name:t for t in model.inputs},
        outputs={t.name:t for t in model.outputs})