Using Keras model.fit_generator

Date: 2017-09-29 16:43:30

Tags: python-3.x iterator keras generator

When writing a custom generator for training a Keras model, I first tried the generator syntax, so I used yield from __next__. However, when I tried to train my model with model.fit_generator, I got an error saying my generator was not an iterator. The fix was to change yield to return, which also required reworking the logic of __next__ to track state. It is quite cumbersome compared with letting yield do the work for me.

Is there a way I can make yield work? If I have to use return statements, I will need to write several iterators, all with very clunky logic.
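For reference, here is a minimal sketch of the two approaches the question contrasts (the names are illustrative, not from the original post): a plain generator function using yield, versus the equivalent iterator class whose __next__ must return and therefore track its own state explicitly.

```python
# Generator function: `yield` suspends and resumes execution,
# so the loop variable carries the state automatically.
def batch_generator(data, batch_size):
    i = 0
    while True:
        yield data[i:i + batch_size]
        i = (i + batch_size) % len(data)

# Equivalent iterator class: `return` forces the state (`self.i`)
# to be stored and updated by hand on every call to __next__.
class BatchIterator:
    def __init__(self, data, batch_size):
        self.data = data
        self.batch_size = batch_size
        self.i = 0

    def __iter__(self):
        return self

    def __next__(self):
        batch = self.data[self.i:self.i + self.batch_size]
        self.i = (self.i + self.batch_size) % len(self.data)
        return batch
```

Both objects satisfy the iterator protocol, so `next()` on either one yields successive batches.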

4 Answers:

Answer 0 (score: 14)

I can't help debug your code since you didn't post it, but here is an abbreviated version of a custom data generator I wrote for a semantic segmentation project, for you to use as a template:

import os
import random

import cv2
import numpy as np

def generate_data(directory, batch_size):
    """Replaces Keras' native ImageDataGenerator."""
    i = 0
    file_list = os.listdir(directory)
    while True:
        image_batch = []
        for b in range(batch_size):
            if i == len(file_list):
                i = 0
                random.shuffle(file_list)
            sample = file_list[i]
            i += 1
            image = cv2.resize(cv2.imread(os.path.join(directory, sample)), INPUT_SHAPE)
            image_batch.append((image.astype(float) - 128) / 128)  # normalize to [-1, 1]

        yield np.array(image_batch)

Usage:

data_dir = os.path.expanduser('~/my_data')  # os.listdir does not expand '~' itself
model.fit_generator(
    generate_data(data_dir, batch_size),
    steps_per_epoch=len(os.listdir(data_dir)) // batch_size)
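Before handing a generator like this to fit_generator, it can be sanity-checked by pulling a few batches manually with next(). Here is a file-free sketch of the same round-robin-and-shuffle pattern over an in-memory list (names and values are illustrative, standing in for the image files above):

```python
import random
import numpy as np

def generate_batches(samples, batch_size):
    """Same pattern as generate_data, but over an in-memory list."""
    i = 0
    while True:
        batch = []
        for _ in range(batch_size):
            if i == len(samples):
                i = 0
                random.shuffle(samples)  # reshuffle at each epoch boundary
            batch.append(samples[i])
            i += 1
        yield np.array(batch)

gen = generate_batches(list(range(10)), batch_size=4)
first = next(gen)   # samples 0-3
second = next(gen)  # samples 4-7
```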

Answer 1 (score: 6)

I played with Keras generators recently and finally managed to prepare an example. It uses random data, so trying to teach the network anything with it is pointless, but it is a good illustration of using a Python generator with Keras.

Generate some data

import numpy as np
import pandas as pd

data = np.random.rand(200, 2)
expected = np.random.randint(2, size=200).reshape(-1, 1)

dataFrame = pd.DataFrame(data, columns=['a', 'b'])
expectedFrame = pd.DataFrame(expected, columns=['expected'])

dataFrameTrain, dataFrameTest = dataFrame[:100], dataFrame[-100:]
expectedFrameTrain, expectedFrameTest = expectedFrame[:100], expectedFrame[-100:]

The generator

def generator(X_data, y_data, batch_size):
    samples_per_epoch = X_data.shape[0]
    number_of_batches = samples_per_epoch / batch_size
    counter = 0

    while True:
        X_batch = np.array(X_data[batch_size * counter:batch_size * (counter + 1)]).astype('float32')
        y_batch = np.array(y_data[batch_size * counter:batch_size * (counter + 1)]).astype('float32')
        counter += 1
        yield X_batch, y_batch

        # restart counter to yield data in the next epoch as well
        if counter >= number_of_batches:
            counter = 0
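The counter reset can be verified by drawing more batches than one epoch contains and checking that the data repeats. A self-contained check, with the generator restated verbatim and small numpy arrays standing in for the DataFrames:

```python
import numpy as np

def generator(X_data, y_data, batch_size):
    samples_per_epoch = X_data.shape[0]
    number_of_batches = samples_per_epoch / batch_size
    counter = 0
    while True:
        X_batch = np.array(X_data[batch_size * counter:batch_size * (counter + 1)]).astype('float32')
        y_batch = np.array(y_data[batch_size * counter:batch_size * (counter + 1)]).astype('float32')
        counter += 1
        yield X_batch, y_batch
        if counter >= number_of_batches:
            counter = 0  # restart so the generator also covers later epochs

X = np.arange(8).reshape(-1, 1)
y = np.arange(8).reshape(-1, 1)
gen = generator(X, y, batch_size=4)
xb1, _ = next(gen)  # first half of the data
xb2, _ = next(gen)  # second half
xb3, _ = next(gen)  # counter was reset: first half again
```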

Keras model

from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Dense(12, activation='relu', input_dim=dataFrame.shape[1]))
model.add(Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='adadelta', metrics=['accuracy'])

# Train the model using the generator instead of the full batch
batch_size = 8

model.fit_generator(
    generator(dataFrameTrain, expectedFrameTrain, batch_size),
    epochs=3,
    steps_per_epoch=dataFrame.shape[0] // batch_size,
    validation_data=generator(dataFrameTest, expectedFrameTest, batch_size * 2),
    validation_steps=dataFrameTest.shape[0] // (batch_size * 2))

# Without a generator:
# model.fit(x=np.array(dataFrame), y=np.array(expected), batch_size=batch_size, epochs=3)

Output

Epoch 1/3
25/25 [==============================] - 3s - loss: 0.7297 - acc: 0.4750 - 
val_loss: 0.7183 - val_acc: 0.5000
Epoch 2/3
25/25 [==============================] - 0s - loss: 0.7213 - acc: 0.3750 - 
val_loss: 0.7117 - val_acc: 0.5000
Epoch 3/3
25/25 [==============================] - 0s - loss: 0.7132 - acc: 0.3750 - 
val_loss: 0.7065 - val_acc: 0.5000

Answer 2 (score: 0)

This is how I implemented reading files of arbitrary size, and it works like a charm. In my main program I have:

import pandas as pd

# The data file has no header, so build one for pd.read_csv
# to be able to read it in chunks
hdr = ["Col-" + str(i) for i in range(num_labels + num_features)]

def tgen(filename):
    while True:  # re-open the reader at the end of each epoch
        reader = pd.read_csv(filename, chunksize=batch_size, names=hdr, header=None)
        for chunk in reader:
            W = chunk.values        # labels and features
            Y = W[:, :num_labels]   # labels
            X = W[:, num_labels:]   # features
            X = X / 255             # any required transformation
            yield X, Y
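A self-contained way to check this chunked-reading pattern is to write a tiny CSV to a temporary file and pull a few batches; the num_labels, num_features, and batch_size values below are illustrative, not from the original post:

```python
import os
import tempfile

import pandas as pd

num_labels, num_features, batch_size = 1, 2, 2
hdr = ["Col-" + str(i) for i in range(num_labels + num_features)]

def tgen(filename):
    while True:  # re-open the reader each epoch so the generator never exhausts
        reader = pd.read_csv(filename, chunksize=batch_size, names=hdr, header=None)
        for chunk in reader:
            W = chunk.values        # labels and features
            Y = W[:, :num_labels]   # labels
            X = W[:, num_labels:]   # features
            yield X, Y

# Write 4 rows of "label,feat1,feat2"
with tempfile.NamedTemporaryFile("w", suffix=".csv", delete=False) as f:
    f.write("0,10,11\n1,12,13\n0,14,15\n1,16,17\n")
    path = f.name

gen = tgen(path)
X1, Y1 = next(gen)  # rows 0-1
X2, Y2 = next(gen)  # rows 2-3
X3, Y3 = next(gen)  # wrapped around: rows 0-1 again
os.unlink(path)
```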

Answer 3 (score: 0)

I would like to upgrade Vaasha's code to TensorFlow 2.x, to improve training efficiency and simplify data handling. This is particularly useful for image processing.

Either use the generator function Vaasha created in the example above, or process the data with the tf.data.Dataset API. The latter approach is very useful when working with any dataset that comes with metadata. For example, MNIST can be loaded and processed with a few statements.

import tensorflow as tf  # ensure TensorFlow 2.x is used
tf.compat.v1.enable_eager_execution()  # a no-op on TF 2.x, where eager execution is the default
import tensorflow_datasets as tfds  # needed for tf datasets such as MNIST or CIFAR10

mnist_train = tfds.load(name="mnist", split="train")

Load the dataset with tfds.load, then apply whatever processing is needed (e.g., converting categorical variables, resizing, etc.).

Now upgrade the Keras model to TensorFlow 2.x:

model = tf.keras.Sequential()  # TensorFlow 2.x upgrade
model.add(tf.keras.layers.Dense(12, activation='relu', input_dim=dataFrame.shape[1]))
model.add(tf.keras.layers.Dense(1, activation='sigmoid'))

model.compile(loss='binary_crossentropy', optimizer='adadelta', metrics=['accuracy'])

# Train the model using the generator instead of the full batch
batch_size = 8

model.fit_generator(
    generator(dataFrameTrain, expectedFrameTrain, batch_size),
    epochs=3,
    steps_per_epoch=dataFrame.shape[0] // batch_size,
    validation_data=generator(dataFrameTest, expectedFrameTest, batch_size * 2),
    validation_steps=dataFrameTest.shape[0] // (batch_size * 2))

This upgrades the model to run on TensorFlow 2.x.