如何在已加载的模型上设置input_tensor

时间:2018-11-17 17:13:10

标签: tensorflow keras

我已经基于VGG16创建了一个模型:

base_model = VGG16(weights='imagenet', include_top=False,
                   input_tensor=next_batch["image"],
                   input_shape=INPUT_SHAPE)
x = base_model.output
# model customization follows, not relevant

请注意,在代码段中,我指定了input_tensor的{​​{1}}。

我已经对模型进行了一些训练,然后使用tf.data.Dataset

保存了模型

现在,当我尝试使用model.save("model.h5")加载模型并继续对其进行训练时,我得到:

  

回溯(最近通话最近):   tensorflow.python.framework.errors_impl.InvalidArgumentError:您必须   用dtype float和占位符张量'input_1'的值   形状[?,512,512,3]            [[{{node input_1}} = Placeholderdtype = DT_FLOAT,shape = [?, 512,512,3],   _device =“ / job:localhost /副本:0 / task:0 / device:GPU:0”]]

问题是,如何在已加载的模型上指定load_model("model.h5")

1 个答案:

答案 0 :(得分:0)

通读负责模型加载的代码后,我提出了以下解决方案:

class InputLayerFix:
    @staticmethod
    def from_config(config):
        return InputLayer(input_tensor=next_batch["image"],
                          dtype=config["dtype"],
                          name=config["name"],
                          batch_input_shape=config["batch_input_shape"])

custom_objects = {"InputLayer": InputLayerFix}
model.load_model("model.h5", custom_objects=custom_objects)

此代码基于以下事实:反序列化例程首先检查custom_objects以实例化要实例化的类(https://github.com/tensorflow/tensorflow/blob/r1.12/tensorflow/python/keras/utils/generic_utils.py#L155