我从事一些我想创建神经网络的机器学习项目,但是它还不够成功。所以我决定使用daa增强。但是问题来了-我不确定如何将其与我开始研究的架构一起使用。 我的文件夹中的数据集具有以下结构:
<ul>
<li>
<ul>train
<li>class 1</li>
<li>class 2</li>
<li>class 3</li>
<li>class 4</li>
</ul>
</li>
<li><ul> validation
<li>class 1</li>
<li>class 2</li>
<li>class 3</li>
<li>class 4</li>
</ul>
</li></ul>
所有在“数据”文件夹中。主文件夹指示它分别存储用于训练或验证的数据,子文件夹存储特定类别的图像。 (该类表示图像上的第1、2或3个对象(前三个可能是不同的,这意味着:同一幅图像可以延续由两个或什至3个类定义的事物),第4类表示“不存在类型为1的对象,图片上的2或3”) 我将这些文件夹的本地化存储在两个变量中:“ train_dir”和“ validation_dir”(通过使用os库获得)
我通过以下方式准备数据和生成器:
train_image_generator = ImageDataGenerator(
rotation_range=20,
width_shift_range=0.1,
height_shift_range=0.1,
shear_range=10,
brightness_range= (0.2, 0.8),
horizontal_flip=True,
vertical_flip=True,
fill_mode='nearest',
rescale=1./255 #before augmentation ussage that was the only line
)
validation_image_generator = ImageDataGenerator(rescale=1./255)
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
directory=train_dir,
shuffle=True,
target_size=(IMG_HEIGHT, IMG_WIDTH),
class_mode='binary')
val_data_gen = validation_image_generator.flow_from_directory(batch_size=batch_size,
directory=validation_dir,
target_size=(IMG_HEIGHT, IMG_WIDTH),
class_mode='binary')
然后我准备模型:
model = Sequential([
Conv2D(32, 3, padding='same', activation='relu', input_shape=(IMG_HEIGHT, IMG_WIDTH ,3)),
MaxPooling2D(),
Conv2D(64, 3, padding='same', activation='relu'),
MaxPooling2D(),
Conv2D(128, 3, padding='same', activation='relu'),
MaxPooling2D(),
Flatten(),
Dense(64, activation='relu'),
Dense(1, activation='sigmoid')
])
model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy'])
model.summary()
然后我卡住了。我不知道如何使用model.fit_generator()。我的意思是,我应该使用什么参数。 这是我之前决定使用增强功能的条件:
history = model.fit_generator(
train_data_gen,
steps_per_epoch=total_train // batch_size, #I know that total_train is value I will have to recalculate while using augmentation.
epochs=epochs,
validation_data=val_data_gen,
validation_steps=total_val // batch_size
)
我试图用'validation_image_generator.flow(train_data_gen)'代替'train_data_gen'