运行VGG16模型时数据集大小增加

时间:2019-06-28 04:12:45

标签: python deep-learning vgg-net

我有训练(train)和验证(valid)数据集,每个文件夹包含493个类别;训练集中每个类别有30张图像,验证集中每个类别有20张图像。

运行代码时,应当统计出:训练集 493 × 30 = 14790 张图像,验证集 493 × 20 = 9860 张图像。

但它实际统计出了更多图像,例如:训练集 = 14830,验证集 = 9890。

代码是:

import os, sys, json
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint
from keras.applications.vgg16 import VGG16
from keras.models import Model, Sequential
from keras.layers import Input, Activation, Dropout, Flatten, Dense
from keras import optimizers

# Training configuration and dataset/output locations.
nb_epoch = 20  # number of training epochs

result_dir = './results'
train_dir = '/Users/sripdeep/Desktop/Krupali/crab_vgg16/crabdata_vgg16/train'
valid_dir = '/Users/sripdeep/Desktop/Krupali/crab_vgg16/crabdata_vgg16/valid'
# makedirs(exist_ok=True) avoids the check-then-create race of
# os.path.exists() followed by os.mkdir().
os.makedirs(result_dir, exist_ok=True)

def save_history(history, result_file):
    """Write per-epoch training metrics to a tab-separated text file.

    Args:
        history: Keras ``History`` object — anything exposing a ``.history``
            dict that maps metric names to per-epoch lists of floats.
        result_file: path of the file to (over)write.
    """
    hist = history.history
    loss = hist['loss']
    # Keras >= 2.3 renamed 'acc' -> 'accuracy' (and 'val_acc' -> 'val_accuracy');
    # accept either spelling so the function works across Keras versions.
    acc = hist['acc'] if 'acc' in hist else hist['accuracy']
    val_loss = hist['val_loss']
    val_acc = hist['val_acc'] if 'val_acc' in hist else hist['val_accuracy']
    n_epochs = len(acc)  # local name; avoid shadowing the module-level nb_epoch

    with open(result_file, "w") as fp:
        fp.write("epoch\tloss\tacc\tval_loss\tval_acc\n")
        for i in range(n_epochs):
            fp.write("%d\t%f\t%f\t%f\t%f\n" % (i, loss[i], acc[i], val_loss[i], val_acc[i]))

if __name__ == '__main__':

    # Input image size expected by VGG16 (224x224 RGB) and number of classes.
    h = 224 
    w = 224 
    nb_class = 493 
    ckpt_file = 'ckpt-weight.h5'

    # model construction
    # ImageNet-pretrained VGG16 without its original classifier head
    # (include_top=False); a small dense classifier is stacked on top.
    input_tensor = Input(shape=(h, w, 3))
    vgg16_model = VGG16(include_top=False, weights='imagenet', input_tensor=input_tensor)

    top_model = Sequential()
    top_model.add(Flatten(input_shape=vgg16_model.output_shape[1:]))
    top_model.add(Dense(256, activation='relu'))
    top_model.add(Dropout(0.5))
    top_model.add(Dense(nb_class, activation='softmax'))

    # NOTE(review): 'input'/'output' are the legacy Keras 1 keyword names;
    # Keras 2 expects Model(inputs=..., outputs=...) — confirm installed version.
    model = Model(input=vgg16_model.input, output=top_model(vgg16_model.output))

    #--- set the first 25 layers (up to the last conv block)
    #--- to non-trainable (weights will not be updated)
    #for layer in model.layers[:25]:
    #    layer.trainable = False

    #model.load_weights(os.path.join(result_dir, ckpt_file))
    model.summary()

    # Low learning rate SGD — typical when fine-tuning pretrained weights.
    model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
              metrics=['accuracy'])



    # Augmentation is applied to training data only; validation images are
    # just rescaled to [0, 1].
    train_datagen = ImageDataGenerator(
        rescale=1.0 / 255,
        rotation_range=20,
        width_shift_range=0.1,
        height_shift_range=0.1,
        #zoom_range=0.2,
        #vertical_flip=True,
        #horizontal_flip=True,
        #channel_shift_range=0.2,
        #shear_range=0.1
    )

    test_datagen = ImageDataGenerator(
        rescale=1.0 / 255
    )

    # NOTE(review): flow_from_directory counts every file with a recognized
    # image extension inside each class sub-folder. If it reports more images
    # than classes * images-per-class (e.g. 14830 instead of 14790), some class
    # folders contain extra or duplicate image files — verify the directory
    # contents on disk rather than the code.
    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(h, w),
        batch_size=32,
        class_mode='categorical')
        #class_mode='binary')

    validation_generator = test_datagen.flow_from_directory(
        valid_dir,
        target_size=(h, w),
        batch_size=32,
        class_mode='categorical')
        #class_mode='binary')

    # Persist the class-name -> index mapping so predictions can be decoded later.
    print(train_generator.class_indices)
    with open('class_indices.json', 'w') as f:
        json.dump(train_generator.class_indices, f,
             indent=4,
             sort_keys=True)

    # training
    # Keep only the best-performing checkpoint seen during training.
    ckpt = ModelCheckpoint(filepath=ckpt_file, verbose=1, save_best_only=True)

    # NOTE(review): samples_per_epoch / nb_epoch / nb_val_samples are Keras 1
    # argument names; Keras 2 uses steps_per_epoch / epochs / validation_steps.
    # Also, samples_per_epoch=100 processes only 100 samples per epoch, a tiny
    # fraction of the ~14790-image training set — confirm this is intentional.
    history = model.fit_generator(
        train_generator,
        samples_per_epoch=100,
        nb_epoch=nb_epoch,
        validation_data=validation_generator,
        nb_val_samples=50,
        callbacks=[ckpt])

    # save results: final weights plus the per-epoch metric log
    model.save_weights(os.path.join(result_dir, 'ckpt-weight-last.h5'))
    save_history(history, os.path.join(result_dir, 'ckpt-history.txt'))

它生成(上面的代码运行):

找到了属于493个类别的14830个图像。 找到了属于493个类别的9890个图像。

我应该生成:

找到了属于493个类别的 14790 个图像。 找到了属于493个类别的 9860 个图像。

0 个答案:

没有答案