内存不足错误,带有keras的Tensorflow集线器

时间:2019-05-28 12:16:23

标签: python tensorflow keras hpc tensorflow-hub

即使在HPC群集上运行此程序后,我仍然出现内存不足错误。我正在做数据扩充。在此之前,它运行良好。我已经尝试减小批处理大小,控制keras会话以防止内存泄漏。我还减少了数据扩充步骤,仅进行水平和垂直翻转。我仍然面对这个问题。在此之前,数据扩充工作正常。我数据的当前形状是:

Training 1830 images
validation 317 images
Image batch shape:  (32, 299, 299, 3)
Label batch shape:  (32, 11)

这是我的程序代码:

feature_extractor_url = "https://tfhub.dev/google/imagenet/inception_v3/feature_vector/3" 
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1/255, validation_split=0.15)

# function to load the model
def feature_extractor(x):
  feature_extractor_module = hub.Module(feature_extractor_url)
  return feature_extractor_module(x)

IMAGE_SIZE = hub.get_expected_image_size(hub.Module(feature_extractor_url))
features_extractor_layer = layers.Lambda(feature_extractor, input_shape= IMAGE_SIZE)
features_extractor_layer.trainable = True
image_data = image_generator.flow_from_directory(str(data_root),  target_size=IMAGE_SIZE,subset='training' )
image_data_val = image_generator.flow_from_directory(str(data_root), target_size=IMAGE_SIZE,subset='validation')


#data augmentation
datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    vertical_flip=True,
    rescale=1/255,
    rotation_range=20,
    horizontal_flip=True)

K.clear_session()

datagen.fit(image_data)

model = tf.keras.Sequential([
  features_extractor_layer,
  layers.Dense(image_data.num_classes, activation='softmax')
])

result = model(image_batch)
result.shape
model.compile(optimizer=tf.keras.optimizers.Adam(), loss='categorical_crossentropy',metrics=['accuracy'])
steps_per_epoch = image_data.samples//image_data.batch_size
model.fit_generator(datagen.flow(image_batch, label_batch,  batch_size=2),steps_per_epoch, epochs=2,
                    validation_data = (item for item in image_data_val), 
                    validation_steps =image_data_val.samples/image_data_val.batch_size, callbacks = [batch_stats],
                    verbose=1)

程序总是在两个小时后失败,并且正如日志中所显示的那样,当时它正在消耗1000GB的内存。

我正在进行数据扩充以提高准确性,目前11类的准确率达到75%。由于我没有足够的数据,因此尽管要使用keras的数据增强技术。因此,我正在寻找的解决方案是:1)如何删除此错误并能够进行数据扩充? 2)通过其他什么方法可以提高模型的整体准确性? 最近几天一直坚持下去。我将不胜感激。

0 个答案:

没有答案