I am training my model on a single P100, with 32 GB of RAM and up to 16 cores available.
import glob
import random

import numpy as np
import tensorflow as tf

class DataGenerator(tf.keras.utils.Sequence):
    def __init__(self, images, input_images=5, predict_images=5, batch_size=16,
                 image_size=(200, 200), channels=1):
        self.images = images
        self.input_images = input_images
        self.predict_images = predict_images
        self.batch_size = batch_size
        self.image_size = image_size
        self.channels = channels
        self.nr_images = int(len(self.images) - input_images - predict_images)

    def __len__(self):
        return int(np.floor(self.nr_images / self.batch_size))

    def __getitem__(self, item):
        # Randomly select the beginning image of each batch
        batch_indices = random.sample(range(0, self.nr_images), self.batch_size)
        # Allocate the output arrays
        x = np.empty((self.batch_size, self.input_images,
                      *self.image_size, self.channels), dtype='uint8')
        y = np.empty((self.batch_size, self.predict_images,
                      *self.image_size, self.channels), dtype='uint8')
        # Get the lists of input and prediction images
        for i in range(self.batch_size):
            list_images_input = range(batch_indices[i],
                                      batch_indices[i] + self.input_images)
            list_images_predict = range(batch_indices[i] + self.input_images,
                                        batch_indices[i] + self.input_images
                                        + self.predict_images)
            # Read in the input images
            for j, ID in enumerate(list_images_input):
                x[i, j] = np.reshape(np.load(self.images[ID]),
                                     (*self.image_size, self.channels))
            # Read in the prediction images
            for j, ID in enumerate(list_images_predict):
                y[i, j] = np.reshape(np.load(self.images[ID]),
                                     (*self.image_size, self.channels))
        return x, y
# Training the model using fit_generator
params = {'batch_size': 8,
          'input_images': 5,
          'predict_images': 5,
          'image_size': (100, 100),
          'channels': 1}
data_path = "input_frames/"
input_images = sorted(glob.glob(data_path + "*.png"))
training_generator = DataGenerator(input_images, **params)
model.fit_generator(generator=training_generator, epochs=10, workers=6)
I expected Keras to prepare the next batch of data while the current batch is being processed on the GPU, but it does not seem to keep up. In other words, preparing the data before it is sent to the GPU appears to be the bottleneck.
Any ideas on how to improve the performance of a data generator like this? Is there something missing that would guarantee the data is prepared in time?
Thanks a lot!
Answer 0 (score: 0)
When using fit_generator, there is a workers= argument that scales up the number of generator workers. However, you should make sure that __getitem__ actually takes the 'item' argument into account, so that the different workers (which are not synchronized) return different batches depending on the item index. That is, instead of sampling randomly inside __getitem__, just return the slice of the data that corresponds to the index. You can shuffle the entire dataset once before training (and again at the end of each epoch) so that the overall order is still random.
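A minimal sketch of that idea follows. The class name `IndexedGenerator` and its parameters are illustrative, not from the question; it is written as a plain class here so it runs standalone, but in real code it would subclass tf.keras.utils.Sequence and return (x, y) arrays instead of raw indices.

```python
import numpy as np

class IndexedGenerator:  # subclass tf.keras.utils.Sequence in real code
    """Each batch is determined solely by `item`, so unsynchronized
    workers never duplicate or skip samples within an epoch."""

    def __init__(self, n_samples, batch_size=8, seed=0):
        self.batch_size = batch_size
        self.rng = np.random.default_rng(seed)
        self.indices = np.arange(n_samples)
        self.on_epoch_end()  # initial shuffle

    def __len__(self):
        # Number of full batches per epoch
        return len(self.indices) // self.batch_size

    def __getitem__(self, item):
        # Deterministic slice: batch `item` always maps to the same
        # per-epoch index range, regardless of which worker runs it.
        return self.indices[item * self.batch_size:(item + 1) * self.batch_size]

    def on_epoch_end(self):
        # Keras calls this between epochs; reshuffle so the order
        # is still random overall.
        self.rng.shuffle(self.indices)
```

With this layout, every sample index is yielded exactly once per epoch, and two workers asked for different `item` values can never return overlapping batches.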
Answer 1 (score: 0)
Can you try use_multiprocessing=True? These are the numbers I observe on my GTX 1080Ti-based system with the data generator you provided.
model.fit_generator(generator=training_generator, epochs=10, workers=6)
148/148 [==============================] - 9s 60ms/step
model.fit_generator(generator=training_generator, epochs=10, workers=6, use_multiprocessing=True)
148/148 [==============================] - 2s 11ms/step