我有训练集(train)和验证集(valid)两个数据集文件夹,每个文件夹包含493个类别;训练集的每个类别包含30张图像,验证集的每个类别包含20张图像。
运行代码时,应该找到:训练集 493 * 30 = 14790 张图像,验证集 493 * 20 = 9860 张图像。
但实际找到的图像数量不同,例如:训练集 = 14830,验证集 = 9890。
代码是:
import os, sys, json
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint
from keras.applications.vgg16 import VGG16
from keras.models import Model, Sequential
from keras.layers import Input, Activation, Dropout, Flatten, Dense
from keras import optimizers
# Number of training epochs passed to fit_generator.
nb_epoch = 20
# Directory that receives the final weights and the training history file.
result_dir = './results'
# Roots of the training / validation image trees (one sub-folder per class).
train_dir = '/Users/sripdeep/Desktop/Krupali/crab_vgg16/crabdata_vgg16/train'
valid_dir = '/Users/sripdeep/Desktop/Krupali/crab_vgg16/crabdata_vgg16/valid'

# Create the output directory up front. Attempting the creation and
# inspecting the failure avoids the check-then-create race and also
# handles nested paths, while staying valid on both Python 2 and 3.
try:
    os.makedirs(result_dir)
except OSError:
    # An already existing directory is fine; anything else (permissions,
    # a regular file in the way) is a real error and must surface.
    if not os.path.isdir(result_dir):
        raise
def save_history(history, result_file):
    """Write the per-epoch training metrics to a tab-separated text file.

    The file starts with the header ``epoch  loss  acc  val_loss  val_acc``
    followed by one row per epoch (epoch numbers start at 0).

    Args:
        history: Object whose ``history`` attribute maps metric names to
            per-epoch lists of values (as returned by ``fit_generator``).
        result_file: Path of the text file to create or overwrite.

    Raises:
        KeyError: If the loss or accuracy series is missing from
            ``history.history``.
    """
    metrics = history.history

    def pick(*names):
        # The accuracy series is stored as 'acc' by older Keras releases and
        # as 'accuracy' by newer ones; accept whichever is present.
        for name in names:
            if name in metrics:
                return metrics[name]
        raise KeyError(names[0])

    loss = metrics['loss']
    acc = pick('acc', 'accuracy')
    val_loss = metrics['val_loss']
    val_acc = pick('val_acc', 'val_accuracy')

    with open(result_file, "w") as fp:
        fp.write("epoch\tloss\tacc\tval_loss\tval_acc\n")
        # One row per epoch for which all four series have a value.
        for i, row in enumerate(zip(loss, acc, val_loss, val_acc)):
            fp.write("%d\t%f\t%f\t%f\t%f\n" % ((i,) + tuple(row)))
if __name__ == '__main__':
    # Fine-tune an ImageNet-initialised VGG16 with a new dense classifier
    # head on the images found under train_dir / valid_dir, checkpointing
    # the best weights and saving the final weights and metric history.
    h = 224
    w = 224
    nb_class = 493
    ckpt_file = 'ckpt-weight.h5'
    # Keep the best-epoch checkpoint in result_dir together with the other
    # outputs; this is also where the optional load_weights call below
    # looks for it when resuming.
    ckpt_path = os.path.join(result_dir, ckpt_file)

    # model construction: convolutional VGG16 base + small dense head
    input_tensor = Input(shape=(h, w, 3))
    vgg16_model = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)
    top_model = Sequential()
    top_model.add(Flatten(input_shape=vgg16_model.output_shape[1:]))
    top_model.add(Dense(256, activation='relu'))
    top_model.add(Dropout(0.5))
    top_model.add(Dense(nb_class, activation='softmax'))
    model = Model(input=vgg16_model.input, output=top_model(vgg16_model.output))

    #--- set the first 25 layers (up to the last conv block)
    #--- to non-trainable (weights will not be updated)
    #for layer in model.layers[:25]:
    #    layer.trainable = False
    #model.load_weights(ckpt_path)

    model.summary()
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
                  metrics=['accuracy'])

    # Training images are rescaled to [0, 1] and lightly augmented;
    # validation images are only rescaled.
    train_datagen = ImageDataGenerator(
        rescale=1.0 / 255,
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        #zoom_range=0.2,
        #vertical_flip=True,
        #horizontal_flip=True,
        #channel_shift_range=0.2,
        #shear_range=0.1
    )
    test_datagen = ImageDataGenerator(
        rescale=1.0 / 255
    )

    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(h, w),
        batch_size=32,
        class_mode='categorical')
        #class_mode='binary')
    validation_generator = test_datagen.flow_from_directory(
        valid_dir,
        target_size=(h, w),
        batch_size=32,
        class_mode='categorical')
        #class_mode='binary')

    # Persist the class-name -> index mapping so predictions can be decoded
    # later.
    print(train_generator.class_indices)
    with open('class_indices.json', 'w') as f:
        json.dump(train_generator.class_indices, f,
                  indent=4,
                  sort_keys=True)

    # training
    ckpt = ModelCheckpoint(filepath=ckpt_path, verbose=1, save_best_only=True)
    # NOTE(review): samples_per_epoch=100 and nb_val_samples=50 look far
    # smaller than the number of images the generators report -- confirm
    # these are intended rather than the full sample counts.
    history = model.fit_generator(
        train_generator,
        samples_per_epoch=100,
        nb_epoch=nb_epoch,
        validation_data=validation_generator,
        nb_val_samples=50,
        callbacks=[ckpt])

    # save results
    model.save_weights(os.path.join(result_dir, 'ckpt-weight-last.h5'))
    save_history(history, os.path.join(result_dir, 'ckpt-history.txt'))
运行上面的代码时,实际输出为:
找到了属于493个类别的 14830 个图像。 找到了属于493个类别的 9890 个图像。
而我期望的输出是:
找到了属于493个类别的 14790 个图像。 找到了属于493个类别的 9860 个图像。