在单个文件中保存大量的numpy数组,并使用它来适应keras模型

时间:2019-03-07 16:07:47

标签: numpy machine-learning keras hdf5 numpy-ndarray

我有大量不适合RAM的numpy数组。可以说数以百万计:

np.arange(10) 
  1. 我希望将它们逐个文件地保存在文件系统中。
  2. 我想从文件中读取它们,并使用model.fit_generator
  3. 将它们提供给我的keras模型。

我了解到dask适用于无法容纳在内存中但无法实现目标的大数据。

1 个答案:

答案 0 :(得分:0)

用泡菜将文件写入磁盘:

pickle.dump((x, y), open(file, "wb"), protocol=pickle.HIGHEST_PROTOCOL)

然后创建一个测试和训练文件列表,并创建一个生成器:

def raw_generator(files):
    while 1:      
        for file_num, file in enumerate(files):
            try:
                x, y = pickle.load(open(file, 'rb'))                   
                batches = int(np.ceil(len(y) / batch_size))
                for i in range(0, batches):                        
                    end = min(len(x), i * batch_size + batch_size)
                    yield x[i * batch_size:end], y[i * batch_size:end]

            except EOFError:
                print("error" + file)

train_gen = preprocessing.generator(training_files)
test_gen = preprocessing.generator(test_files)

最后调用fit_generator:

history = model.fit_generator(
                generator=train_gen,
                steps_per_epoch= (len(training_files)*data_per_file)/batch_size,
                epochs=epochs
                validation_data=test_gen,
                validation_steps=(len(test_files)*data_per_file)/batch_size,        
                use_multiprocessing=False,
                max_queue_size=10,
                workers=1,
                verbose=1)