当通过发电机馈送大量数据时,保存中间层权重?

时间:2018-02-03 20:22:31

标签: tensorflow deep-learning keras

img_width, img_height = 299, 299
batch_size = 6
epochs = 1
classes = 12



train_datagen = ImageDataGenerator(preprocessing_function = preprocess)

train_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size = (img_height, img_width),
batch_size = batch_size, 
class_mode = 'categorical')


base_model = Xception(weights='imagenet', include_top=False)
x =base_model.predict_generator(train_generator, steps=None,
                        max_queue_size=10, workers=1,
                        use_multiprocessing=False, verbose=0)

此方法的问题在于x在运行时被强制保留批量的所有权重,并且由于内存问题最终导致系统崩溃。 所以我无法将其保存为.npy。文件

有没有办法每批保存重量?

1 个答案:

答案 0 :(得分:1)

可以通过以下方式实现:

import math

number_of_examples = len(train_generator.filenames) # number of images
number_of_generator_steps = math.ceil(number_of_examples / (1.0 * batch_size))

current_iteration = 0
for x, _ in train_generator:
    prediction = model.predict(x)
    # here comes your custom saving function.
    current_iteration += 1
    if current_iteration == number_of_generator_steps:
        break