没有优化器的Tensor Flow 2.0负载节省模型

时间:2019-04-07 11:13:41

标签: tensorflow

我训练了模型并按如下方式保存它:

    network.save(os.path.join(args.logdir, "cifar_model.h5") , 
    include_optimizer=False)

现在,我想加载它并像这样继续训练,但这不起作用

model = tf.keras.models.load_model("...\\cifar_model.h5", compile ="False")



    model.compile(
        optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001, decay=1e-6),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
    )

    model.tb_callback = tf.keras.callbacks.TensorBoard(args.logdir, update_freq=1000, profile_batch=1)
    model.tb_callback.on_train_end = lambda *_: None

    datagen = tf.keras.preprocessing.image.ImageDataGenerator(
        rotation_range=15,
        width_shift_range=0.1,
        height_shift_range=0.1,
        horizontal_flip=True,
    )
    datagen.fit(cifar.train.data["images"])

    model.fit_generator(
        # cifar.train.data["images"], cifar.train.data["labels"],
        datagen.flow(cifar.train.data["images"], cifar.train.data["labels"], batch_size=args.batch_size),
        # batch_size=args.batch_size,
        steps_per_epoch=200,
        epochs=args.epochs,
        validation_data=(cifar.dev.data["images"], cifar.dev.data["labels"]),
        callbacks=[model.tb_callback],
    )

它抛出一个错误:

AttributeError: 'Network' object has no attribute 'compile'

这应该符合https://www.tensorflow.org/alpha/guide/keras/saving_and_serializing

请注意,我在没有优化程序的情况下进行保存,因此可以避免加载优化程序时出现错误。

更新: 当我知道图层的确切结构时,我就知道了如何做。 我知道,然后我可以重新创建模型并使用像这样加载的模型中的权重:

load = tf.keras.models.load_model("...\\cifar_model.h5", compile ="False")
model.set_weights(load.get_weights())

但是我不能将相同的方法应用于load.layers,如果您没有连续的层,我认为这是可能的

0 个答案:

没有答案