我们是否可以为ImageDataGenerator添加一些功能,以便ImageDataGenerator可以获取文件名列表,以及每个小批量的随机样本图像?
我知道我可以自定义一个继承ImageDataGenerator类的类,但我仍然不知道如何做到这一点。
这就是我所做的:
for epoch in range(epochs):
print("epoch is: %d, total epochs: %f" % ((epoch+1), int(epochs)))
print("prepare training batch...")
train_batch = makebatch(filelist=self.train_files, img_num=img_num, slice_times=slice_times)
print("prepare validation batch..")
val_batch = makebatch(filelist=self.val_files, img_num=int(math.ceil(img_num*0.2)), slice_times=slice_times)
x_train = train_batch
y_train = x_train
x_val = val_batch
y_val = x_val
print("generate training data...")
train_datagen.fit(x_train)
train_generator = train_datagen.flow(
x=x_train,
y=y_train,
batch_size=16)
val_datagen.fit(x_val)
val_generator = val_datagen.flow(
x=x_val,
y=y_val,
batch_size=16)
print("start training..")
history = model.fit_generator(
generator=train_generator,
steps_per_epoch=None,
epochs=1,
verbose=1,
validation_data=val_generator,
validation_steps=None,
callbacks=self.callbacks)
我真正想要获得的是我可以删除每个批次的for循环和生成器随机样本图像。
有人可以提供帮助吗?
答案 0 :(得分:1)
在这里,我会做什么。
假设我有一个存储在变量X_train,X_validation中的所有图像的路径列表,并且标签存储为y_train和y_validation。
首先,我将定义一个序列生成器。 (这是来自keras网站)
from skimage.io import imread
from skimage.transform import resize
import numpy as np
# Here, `x_set` is list of path to the images
# and `y_set` are the associated classes.
class CIFAR10Sequence(Sequence):
def __init__(self, x_set, y_set, batch_size):
self.x, self.y = x_set, y_set
self.batch_size = batch_size
def __len__(self):
return int(np.ceil(len(self.x) / float(self.batch_size)))
def __getitem__(self, idx):
batch_x = self.x[idx * self.batch_size:(idx + 1) * self.batch_size]
batch_y = self.y[idx * self.batch_size:(idx + 1) * self.batch_size]
return np.array([
resize(imread(file_name), (200, 200))
for file_name in batch_x]), np.array(batch_y)
现在,我将定义生成器用于培训和验证
Xtrain_gen = detracSequence(X_train,y_train,batch_size=512) # you can choose your batch size.
Xvalidation_gen = detracSequence(X_validation,y_validation,batch_size=512)
现在,培训模型的最后一步
model.fit_generator(generator=Xtrain_gen, epochs=100, validation_data=Xvalidation_gen,use_multiprocessing=True)
这将为您避免for循环,并且它非常有效,因为CPU并行获取数据。