将元数据添加到tensorflow冻结图pb

时间:2019-02-12 03:58:33

标签: python tensorflow

要共享我们训练有素的张量流网络,我们将图形冻结到.pb文件中。我们还创建一个带有一些元数据的xml文件,例如输入张量和输出张量,要应用的预处理类型,训练数据信息等。然后,通过加载图形并评估张量,使用Java或C#为模型提供服务。

为了使共享更容易,我想将此XML数据包括在.pb文件中的某个位置。有什么办法吗?一种想法是将其作为tf.Constant,但我不知道如何将其连接到普通图形。

请注意,这正在使用freeze_graph.py。新的SavedModel格式更合适吗?

1 个答案:

答案 0 :(得分:3)

首先,是的,您应该使用新的SavedModel格式,因为TF团队将支持该格式,并且也可以与Keras一起使用。您可以向模型添加一个附加端点,该端点将返回一个带有您的XML数据字符串的常数张量(如上所述)。

这很好,因为它是密封的-底层的savemodel格式无关紧要,因为您的元数据保存在计算图本身中。

请参阅以下问题的答案:Saving a TF2 keras model with custom signature defs。这个答案无法为您提供100%的Keras解决方案,因为它不能与tf.keras.models.load函数很好地互操作,因为它们将其包装在tf.Module中。幸运的是,如果添加tf.function装饰器,则在tf.keras.Model中也可以使用TF2:

class MyModel(tf.keras.Model):

  def __init__(self, metadata, **kwargs):
    super(MyModel, self).__init__(**kwargs)
    self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
    self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
    self.metadata = tf.constant(metadata)

  def call(self, inputs):
    x = self.dense1(inputs)
    return self.dense2(x)

  @tf.function(input_signature=[])
  def get_metadata(self):
    return self.metadata

model = MyModel('metadata_test')
input_arr = tf.random.uniform((5, 5, 1)) # This call is needed so Keras knows its input shape. You could define manually too
outputs = model(input_arr)

然后,您可以按以下步骤保存和加载模型:

tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')

最后使用model_loaded.get_metadata()来获取常量元数据张量。