我正在Google Colab中编写Jupyter笔记本,该笔记本使用keras
将预先训练的模型拟合到图像分类数据集中。即使我认为我的手工生成器不应导致内存问题,但RAM却已填满并导致Colab会话崩溃。
我试图尽可能地精简我的代码(第4、5和6点是相关的):
1)一些进口
import os
import tarfile
import urllib.request
import numpy as np
import pandas as pd
import skimage.io
import tensorflow as tf
2)数据集准备
dataset = {
'name': 'oxford-102-flowers',
'url': 'https://s3.amazonaws.com/fast-ai-imageclas/oxford-102-flowers.tgz',
'num_classes': 102,
}
base_dir = os.getcwd()
data_dir = os.path.join(base_dir, 'data')
dataset_dir = os.path.join(data_dir, dataset['name'])
os.makedirs(data_dir)
tgz_file = os.path.join(data_dir, dataset['name']+'.tgz')
urllib.request.urlretrieve(dataset['url'], tgz_file)
tar = tarfile.open(tgz_file, 'r:gz')
tar.extractall(data_dir)
tar.close()
os.remove(tgz_file)
train_txt = os.path.join(dataset_dir, 'train.txt')
train_df = pd.read_csv(train_txt, sep=' ', header=None, names=('file', 'label'))
3)模型设置
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
base_model = tf.keras.applications.MobileNetV2(include_top=False)
pooling_dense_model = tf.keras.Sequential([
base_model,
tf.keras.layers.GlobalAveragePooling2D(),
tf.keras.layers.Dense(dataset['num_classes']),
])
pooling_dense_model.compile(
loss=loss,
optimizer='adam',
metrics=['acc'],
)
4)建立一个发电机
def make_train_generator():
while True:
for i, record in train_df.iterrows():
filename = os.path.join(dataset_dir, record['file'])
image = skimage.io.imread(os.path.join(dataset_dir, filename))
normalised_image = image/255. - 0.5
image_batch = normalised_image[np.newaxis, ...]
label_batch = np.array((record['label'],))
yield image_batch, label_batch
5)检查生成器本身不会产生内存泄漏
train_generator = make_train_generator()
train_len = len(train_df)
for i in range(2*train_len):
train_example = next(train_generator)
6)训练模型
pooling_dense_model_history = pooling_dense_model.fit_generator(
generator=train_generator,
steps_per_epoch=train_len,
epochs=2,
verbose=1,
)
现在点5)不会增加RAM消耗,而在6)的步骤中,内存压力会增加,直到在时代结束时它得到额外的提升并且Colab会话崩溃为止。
我广泛浏览了与fit_generator
内存消耗相关的解决方案,但是与上述设置相比,它们似乎都具有更多的复杂性。
关于我要去哪里的任何想法吗?
PS我正在使用python3
会话进行GPU加速,而tf.__version__
是1.15.0-rc3
。