要共享我们训练有素的张量流网络,我们将图形冻结到.pb
文件中。我们还创建一个带有一些元数据的xml文件,例如输入张量和输出张量,要应用的预处理类型,训练数据信息等。然后,通过加载图形并评估张量,使用Java或C#为模型提供服务。
为了使共享更容易,我想将此XML数据包括在.pb
文件中的某个位置。有什么办法吗?一种想法是将其作为tf.Constant,但我不知道如何将其连接到普通图形。
请注意,这正在使用freeze_graph.py
。新的SavedModel格式更合适吗?
答案 0 :(得分:3)
首先,是的,您应该使用新的SavedModel格式,因为TF团队将支持该格式,并且也可以与Keras一起使用。您可以向模型添加一个附加端点,该端点将返回一个带有您的XML数据字符串的常数张量(如上所述)。
这很好,因为它是密封的-底层的savemodel格式无关紧要,因为您的元数据保存在计算图本身中。
请参阅以下问题的答案:Saving a TF2 keras model with custom signature defs。这个答案无法为您提供100%的Keras解决方案,因为它不能与tf.keras.models.load函数很好地互操作,因为它们将其包装在tf.Module
中。幸运的是,如果添加tf.function装饰器,则在tf.keras.Model
中也可以使用TF2:
class MyModel(tf.keras.Model):
def __init__(self, metadata, **kwargs):
super(MyModel, self).__init__(**kwargs)
self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)
self.metadata = tf.constant(metadata)
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
@tf.function(input_signature=[])
def get_metadata(self):
return self.metadata
model = MyModel('metadata_test')
input_arr = tf.random.uniform((5, 5, 1)) # This call is needed so Keras knows its input shape. You could define manually too
outputs = model(input_arr)
然后,您可以按以下步骤保存和加载模型:
tf.keras.models.save_model(model, 'test_model_keras')
model_loaded = tf.keras.models.load_model('test_model_keras')
最后使用model_loaded.get_metadata()
来获取常量元数据张量。