如何在不使用model.fit_generator的情况下将旋转应用于Keras中的图像?

时间:2018-02-16 02:47:03

标签: python machine-learning deep-learning keras conv-neural-network

我正在研究使用卷积神经网络的图像像素分类问题。 我的培训images的大小为128x128x3,大小为mask 标签128x128Xtrain, Xvalid, ytrain, yvalid = train_test_split(images, masks,test_size=0.3, random_state=567) model.fit(Xtrain, ytrain, batch_size=32, epochs=20, verbose=1, shuffle=True, validation_data=(Xvalid, yvalid))

我在Keras接受如下培训:

Xtrain

但是,我想将随机2D旋转应用于ytrain128x128x3,其大小分别为128x128model.fit。更具体地说,我想在每个纪元迭代中应用此旋转。

目前,我想继续使用model.fit_generator而不使用.fit_generator,因为我知道数据扩充通常是使用model.fit完成的。

基本上,我想循环Xtrain,以便ytrainself.ignCB = wx.CheckBox(self, 0, label='Ignore Errors') self.ignCB.Bind(wx.EVT_MOTION, self.on_mouse_over) def on_mouse_over(self, event): self.ignCB.SetToolTipString('Continue on Download Errors, Skips Unavailable Videos within a Playlist') 随机旋转每个纪元。我是Python和Keras的新手,所以如果可能的话,欢迎任何见解。

1 个答案:

答案 0 :(得分:1)

以下是使用ImageDataGenerator将输出保存到指定目录的示例,从而避免了使用model.fit_generator的要求。

from keras.preprocessing.image import ImageDataGenerator, array_to_img, img_to_array, load_img

datagen = ImageDataGenerator(
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')

img = load_img('data/train/cats/cat.0.jpg')  # this is a PIL image
x = img_to_array(img)  # this is a Numpy array with shape (3, 150, 150)
x = x.reshape((1,) + x.shape)  # this is a Numpy array with shape (1, 3, 150, 150)

# the .flow() command below generates batches of randomly transformed images
# and saves the results to the `preview/` directory
i = 0
for batch in datagen.flow(x, batch_size=1,
                          save_to_dir='preview', save_prefix='cat', save_format='jpeg'):
    i += 1
    if i > 20:
        break  # otherwise the generator would loop indefinitely

从这里采取:https://blog.keras.io/building-powerful-image-classification-models-using-very-little-data.html

您可以更改args以适合您的用例,然后生成X_train和X_valid或任何数据集,然后加载到内存中并使用普通的旧model.fit。