如何将具有动态输入形状的keras自定义模型保存为SaveModel格式?

时间:2020-09-18 09:35:43

标签: python tensorflow keras

我有一个具有动态输入形状(灵活的第二维)的自定义模型。

我需要将其保存为SaveModel格式。但它只保存一个签名(第一个使用)。

加载后尝试使用其他签名时-我收到错误消息:

Python输入与input_signature不兼容

我的代码如下:

seq_len = 2
batch_size = 3
import tensorflow as tf

class CustomModule(tf.keras.Model):

  def __init__(self):
    super(CustomModule, self).__init__()
    self.v = tf.Variable(1.)

  #@tf.function
  def call(self, x):
    return x * self.v

module_output = CustomModule()
input = tf.random.uniform([batch_size, seq_len], dtype=tf.float32)
input2 = tf.random.uniform([batch_size, seq_len+1], dtype=tf.float32)
output = tf.random.uniform([batch_size, seq_len], dtype=tf.float32)
output2 = tf.random.uniform([batch_size, seq_len+1], dtype=tf.float32)

optimizer = tf.keras.optimizers.SGD()
training_loss = tf.keras.losses.MeanSquaredError()
module_output.compile(optimizer=optimizer, loss=training_loss)
#hist = module_output.fit(input, output, epochs=1, steps_per_epoch=1, verbose=0)
#hist = module_output.fit(input2, output2, epochs=1, steps_per_epoch=1, verbose=0)

a = module_output(input)               # the first signature
a = module_output(input2)              # the second signature
module_output.save('savedModel/', True, False)
module_output = tf.keras.models.load_model('savedModel/')
a = module_output(input)               # <= it works
a = module_output(input2)              # <= the error is here

我如何使其工作?

编辑: 这是一个玩具的例子。我无法使用功能性API编写模型,因为实际模型太复杂了。

2 个答案:

答案 0 :(得分:1)

尝试使用不同的输入形状和功能性API创建模型:

def create_model(batch_size, seq_len):
   inputs = tf.keras.Input(shape=(batch_size, seq_len)) #input layer
   x = tf.keras.layers...(inputs) # next layer
   x = tf.keras.layers...(x)
   ...
   outputs = tf.keras.layers...(x) # output layer
   model = tf.keras.Model(inputs = inputs, outputs = outputs)
   model.compile(...)
   return model

自从Model继承以来,如果替换了模型声明行,它应该可以工作。

答案 1 :(得分:0)

您可以使用 com.crystaldecisions.sdk.occa.report.application.ReportClientDocument 手动指定装饰 call 函数的输入形状/数据类型。

在您展示的示例中,您可以按如下方式修饰 @tf.function(input_signature=...) 函数:

call