Keras: handling large datasets that do not fit in memory

Time: 2017-08-05 19:02:53

Tags: csv keras training-data large-data

I am working on facial expression recognition with Keras. I collected a lot of data, applied data augmentation to the images, and ended up with about 500,000 images saved as pixel values in a .csv file (same format as fer2013.csv).
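(For reference, a fer2013-style CSV row has the label first, then the space-separated pixel values, then the usage flag; the pixel values below are made up purely for illustration.)

emotion,pixels,Usage
0,70 80 82 72 58 ... 84 85 86,Training
2,151 150 147 155 148 ... 91 89 84,PublicTest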

Here is the code I am using:

import csv
import numpy

def Zerocenter_ZCA_whitening_Global_Contrast_Normalize(pixel_list):
    # pixel_list: flat list of grayscale pixel values for one image.
    # ZeroCenter, zca_whitening, flatten_matrix, global_contrast_normalize,
    # img_width and img_height are defined elsewhere in the script.
    Intonumpyarray = numpy.asarray(pixel_list)
    data = Intonumpyarray.reshape(img_width, img_height)
    data2 = ZeroCenter(data)
    data3 = zca_whitening(flatten_matrix(data2)).reshape(img_width, img_height)
    data4 = global_contrast_normalize(data3)
    data5 = numpy.rot90(data4, 3)
    return data5



def load_data():
    # Each ALL.csv row: row[0] = emotion label, row[1] = space-separated
    # pixel values, row[2] = usage ("Training" / "PublicTest" / "PrivateTest").
    # data_resh is the target shape, defined elsewhere in the script.
    train_x = []
    train_y = []
    val_x = []
    val_y = []
    test_x = []
    test_y = []

    f = open('ALL.csv')
    csv_f = csv.reader(f)

    for row in csv_f:
        if str(row[2]) == "Training":
            temp_list_train = []

            for pixel in row[1].split():
                temp_list_train.append(int(pixel))

            data = Zerocenter_ZCA_whitening_Global_Contrast_Normalize(temp_list_train)
            train_y.append(int(row[0]))
            train_x.append(data.reshape(data_resh).tolist())

        elif str(row[2]) == "PublicTest":
            temp_list_validation = []

            for pixel in row[1].split():
                temp_list_validation.append(int(pixel))

            data = Zerocenter_ZCA_whitening_Global_Contrast_Normalize(temp_list_validation)
            val_y.append(int(row[0]))
            val_x.append(data.reshape(data_resh).tolist())

        elif str(row[2]) == "PrivateTest":
            temp_list_test = []

            for pixel in row[1].split():
                temp_list_test.append(int(pixel))

            data = Zerocenter_ZCA_whitening_Global_Contrast_Normalize(temp_list_test)
            test_y.append(int(row[0]))
            test_x.append(data.reshape(data_resh).tolist())

    return train_x, train_y, val_x, val_y, test_x, test_y

Then I load the data and feed it to the generator:

from keras.utils import np_utils
from keras.preprocessing.image import ImageDataGenerator

# img_rows, img_cols, nb_classes, batch_size, nb_epoch and the compiled
# model are defined earlier in the script.
Train_x, Train_y, Val_x, Val_y, Test_x, Test_y = load_data()

Train_x = numpy.asarray(Train_x) 
Train_x = Train_x.reshape(Train_x.shape[0],img_rows,img_cols)

Test_x = numpy.asarray(Test_x) 
Test_x = Test_x.reshape(Test_x.shape[0],img_rows,img_cols)

Val_x = numpy.asarray(Val_x)
Val_x = Val_x.reshape(Val_x.shape[0],img_rows,img_cols)

Train_x = Train_x.reshape(Train_x.shape[0], img_rows, img_cols, 1)
Test_x = Test_x.reshape(Test_x.shape[0], img_rows, img_cols, 1)
Val_x = Val_x.reshape(Val_x.shape[0], img_rows, img_cols, 1)

Train_x = Train_x.astype('float32')
Test_x = Test_x.astype('float32')
Val_x = Val_x.astype('float32')

Train_y = np_utils.to_categorical(Train_y, nb_classes)
Test_y = np_utils.to_categorical(Test_y, nb_classes)
Val_y = np_utils.to_categorical(Val_y, nb_classes)


datagen = ImageDataGenerator(
    featurewise_center=False,
    samplewise_center=False,
    featurewise_std_normalization=False,
    samplewise_std_normalization=False,
    zca_whitening=False,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    shear_range=0.03,
    zoom_range=0.03,
    vertical_flip=False)

datagen.fit(Train_x)

model.fit_generator(datagen.flow(Train_x, Train_y,
    batch_size=batch_size),
    samples_per_epoch=Train_x.shape[0],
    nb_epoch=nb_epoch,
    validation_data=(Val_x, Val_y))

When I run this code, RAM usage keeps growing until the computer freezes (I have 16 GB). It gets stuck when load_data() is called. Is there a solution to this problem that fits my code?

1 Answer:

Answer 0: (score: 1)

This seems to be a duplicate of this question. Basically, you have to use fit_generator() instead of fit() and pass in a function that loads the data into the model one batch at a time instead of all at once.
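
A minimal sketch of that idea, assuming the same ALL.csv layout as in the question (label in column 0, space-separated pixels in column 1, usage flag in column 2) and reusing the question's preprocessing function; the name csv_batch_generator and the sample-count variables n_train_samples / n_val_samples are illustrative, not part of Keras:

import csv
import numpy
from keras.utils import np_utils

def csv_batch_generator(csv_path, usage, batch_size, nb_classes):
    # Yields (x, y) batches read lazily from the CSV, so only one batch
    # is ever held in memory at a time.
    while True:  # Keras expects the generator to loop forever
        with open(csv_path) as f:
            batch_x, batch_y = [], []
            for row in csv.reader(f):
                if row[2] != usage:
                    continue
                pixels = [int(p) for p in row[1].split()]
                # Same preprocessing as in the question
                data = Zerocenter_ZCA_whitening_Global_Contrast_Normalize(pixels)
                batch_x.append(data.reshape(img_rows, img_cols, 1))
                batch_y.append(int(row[0]))
                if len(batch_x) == batch_size:
                    yield (numpy.asarray(batch_x, dtype='float32'),
                           np_utils.to_categorical(batch_y, nb_classes))
                    batch_x, batch_y = [], []

model.fit_generator(
    csv_batch_generator('ALL.csv', 'Training', batch_size, nb_classes),
    samples_per_epoch=n_train_samples,   # number of "Training" rows in ALL.csv
    nb_epoch=nb_epoch,
    validation_data=csv_batch_generator('ALL.csv', 'PublicTest', batch_size, nb_classes),
    nb_val_samples=n_val_samples)        # number of "PublicTest" rows

Note that this sketch drops the in-memory ImageDataGenerator augmentation; if the random shifts and flips are still needed, they could be applied inside the generator (for example with datagen.random_transform on each image) or by writing the images to disk and using flow_from_directory.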