TensorFlow:将py_func保存到.pb文件

时间:2019-05-21 07:41:51

标签: python tensorflow

我尝试建立一个张量流模型-在这里我使用tf.py_func在普通的python代码中创建部分代码。问题是,当我将模型保存到.pb文件时,.pb文件本身很小,并且不包含py_func:0张量。当我尝试从.pb文件加载并运行模型时,出现以下错误:get ValueError:未找到回调pyfunc_0。

当我不保存并加载为.pb文件时可以使用

有人能帮助你吗?这对我来说非常重要,给了我几个不眠之夜。

model_version = "465555564"
tensorboard = TensorBoard(log_dir='./logs', histogram_freq = 0, write_graph = True, write_images = False)

sess = tf.Session()
K.set_session(sess)
K.set_learning_phase(0)

def my_func(x):
    some_function

input = tf.placeholder(tf.float32)
y = tf.py_func(my_func, [input], tf.float32)

prediction_signature = tf.saved_model.signature_def_utils.predict_signature_def({"inputs": input}, {"prediction": y})
builder = saved_model_builder.SavedModelBuilder('./'+model_version)
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
      sess, [tag_constants.SERVING],
      signature_def_map={
           signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:prediction_signature,
      },
      legacy_init_op=legacy_init_op)

builder.save()

1 个答案:

答案 0 :(得分:0)

一种方法,可以使用tf.py_func保存TF模型,但是您必须使用SavedModel 无需来完成。

TF有2种模型保存级别:检查点和SavedModels。有关更多详细信息,请参见this answer,但请在此处引用:

  
      
  • 检查点包含TensorFlow模型中(某些)变量的值。它是由Saver创建的。要使用检查点,您需要具有兼容的TensorFlow Graph,其Variable与检查点中的Variable具有相同的名称。
  •   
  • SavedModel更为全面:它包含一组Graph(实际上是MetaGraphs,用于保存集合等),以及应该作为检查点的检查点。与这些Graph兼容,并且与运行模型所需的任何资产文件(例如,词汇文件)兼容。对于它包含的每个MetaGraph,它还存储一组签名。签名定义(命名)输入和输出张量。
  •   

tf.py_func操作 不能用SavedModel(在this page in the docs上注明)保存,这是您在此处尝试做的。这是有充分的理由的。 SavedModel应该完全独立于原始代码,并能够以其他可以反序列化的语言来加载。这样可以通过ML Engine之类的东西来加载模型,这些东西可能是用C ++或类似的语言编写的。问题在于它无法序列化任意Python代码,因此py_func是行不通的。

只要您可以使用Python,就可以使用检查点来解决此问题。您将无法获得SavedModel提供的独立性。您可以在使用tf.train.Saver进行训练后保存检查点,然后在新的Session中重新构建整个图形并使用该Saver进行加载。甚至还有一种方法可以在ML Engine中使用该代码,该代码以前专门用于SavedModel。您可以使用custom prediction routines来回避对SavedModel的需求。

有关在the docs中保存/恢复模型的更多信息。