Keras fit_generator throws ValueError

Date: 2017-02-24 20:30:03

Tags: generator keras

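So, I'm trying to create a generator that iterates over a dataset so that I can train with Keras's fit_generator. Here are the definitions of the generator and the model, and the fit_generator call: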

import numpy as np
from collections import deque
from queue import Queue
from keras.models import Sequential
from keras.layers import Dense

num_features = 40
len_data = 100
data = np.random.rand(len_data, num_features)

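def train_generator(train_idxs):
    while True:
        # Non-blocking get: raises queue.Empty once every index has been taken
        i = train_idxs.get(block=False)
        training_example = data[i,:]
        # Add a batch dimension: (num_features,) -> (1, num_features)
        training_example.shape = (1, len(training_example))

        # Autoencoder setup: the input is also the target
        yield (training_example, training_example)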


layer0_size = num_features
layer1_size = layer0_size // 2  # integer division; "/" yields a float in Python 3
layer2_size = layer1_size // 2

layers = []
layers.append(
    Dense(input_dim=layer0_size, output_dim=layer1_size, activation='relu'))
layers.append(
    Dense(input_dim=layer1_size, output_dim=layer2_size, activation='relu'))
layers.append(
    Dense(input_dim=layer2_size, output_dim=layer1_size, activation='relu'))
layers.append(
    Dense(input_dim=layer1_size, output_dim=layer0_size, activation='sigmoid'))

model = Sequential()
for layer in layers:
    model.add(layer)

model.compile(optimizer='adam', loss='binary_crossentropy')

train_idxs = Queue()
train_idxs.queue = deque(range(len_data))
train_gen = train_generator(train_idxs)
max_q_size = 2
model.fit_generator(train_gen, samples_per_epoch=len(data), max_q_size=max_q_size, nb_epoch=1)

Keras then successfully trains on 98 of the 100 training examples and throws this error:

 98/100 [============================>.] - ETA: 0s - loss: 0.6930Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.5/threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 429, in data_generator_task
    generator_output = next(self._generator)
  File "scrap.py", line 12, in train_generator
    i = train_idxs.get(block=False)
  File "/usr/lib/python3.5/queue.py", line 161, in get
    raise Empty
queue.Empty

Traceback (most recent call last):
  File "scrap.py", line 43, in <module>
    model.fit_generator(train_gen, samples_per_epoch=len(data), max_q_size=max_q_size, nb_epoch=1)
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 935, in fit_generator
    initial_epoch=initial_epoch)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1528, in fit_generator
    str(generator_output))
ValueError: output of generator should be a tuple (x, y, sample_weight) or (x, y). Found: None
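The traceback shows that Keras calls next(self._generator) from a separate thread (data_generator_task in Thread-1) while training runs in the main thread. The sketch below is a deterministic, single-threaded imitation of that producer/consumer pattern, not Keras's actual code. The helpers make_index_queue, one_pass_generator and consume_with_prefetch are hypothetical names: the consumer keeps max_q_size items buffered ahead of what it has used, and the generator reproduces the non-blocking get from the question.

```python
import queue
from collections import deque


def make_index_queue(n):
    """Hypothetical helper: a Queue holding the indices 0..n-1."""
    idxs = queue.Queue()
    for i in range(n):
        idxs.put(i)
    return idxs


def one_pass_generator(idxs):
    # Same pattern as train_generator in the question: once every index
    # has been handed out, get(block=False) raises queue.Empty.
    while True:
        yield idxs.get(block=False)


def consume_with_prefetch(gen, n_needed, max_q_size):
    """Consume n_needed items while keeping up to max_q_size items buffered.

    Returns how many items were consumed before the generator raised
    queue.Empty (or n_needed if it never did).
    """
    if max_q_size < 1:
        raise ValueError("max_q_size must be at least 1")
    buffer = deque()
    consumed = 0
    try:
        while consumed < n_needed:
            # Top the buffer up first, the way a prefetching reader would.
            while len(buffer) < max_q_size:
                buffer.append(next(gen))
            buffer.popleft()
            consumed += 1
    except queue.Empty:
        pass  # the generator ran dry while the consumer still needed items
    return consumed


print(consume_with_prefetch(one_pass_generator(make_index_queue(100)), 100, 2))
# prints 99
```

With 100 indices and a buffer of 2, the generator is asked for a 101st item after only 99 have been consumed, so it fails before the epoch is complete. The real Keras run stopped at 98/100; the exact point depends on how Keras's thread and queue interleave.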

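What seems to be happening is that the generator pops every index off train_idxs, and Keras still tries to fetch more examples before it has used up the training examples already in its internal queue. Is there a way to make it stop trying to get more training examples from the generator?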
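The Keras 1.x documentation for fit_generator says the generator is expected to loop over its data indefinitely, with an epoch ending once samples_per_epoch samples have been seen. A minimal sketch of a generator written that way, assuming the same data array and autoencoder targets as above (looping_train_generator, shuffle and seed are hypothetical names, not part of the question or of Keras):

```python
import itertools

import numpy as np

num_features = 40
len_data = 100
data = np.random.rand(len_data, num_features)


def looping_train_generator(data, shuffle=True, seed=None):
    """Yield (x, x) pairs of shape (1, n_features) forever, one pass after another."""
    rng = np.random.RandomState(seed)
    n = data.shape[0]
    while True:
        # New visiting order for every pass over the data.
        order = rng.permutation(n) if shuffle else np.arange(n)
        for i in order:
            x = data[i:i + 1, :]  # slicing keeps the batch dimension: (1, n_features)
            yield (x, x)


# Quick check without Keras: pull 2.5 passes' worth of pairs.
pairs = list(itertools.islice(looping_train_generator(data, seed=0), 250))
print(len(pairs), pairs[0][0].shape)
```

In the question's script this would stand in for train_gen, for example model.fit_generator(looping_train_generator(data), samples_per_epoch=len(data), max_q_size=max_q_size, nb_epoch=1), keeping the same Keras 1.x arguments. The Queue of indices is dropped because a plain loop already provides the order, and a prefetching reader can then ask for extra items at the end of an epoch without the generator running dry.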

0 answers