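I am trying to create a generator that iterates over a data set, for use in training with Keras's fit_generator. Here are the definitions of the generator, the model, and the fit_generator call: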
import numpy as np
from queue import Queue, deque
from keras.models import Sequential
from keras.layers import Dense
num_features = 40
len_data = 100
data = np.random.rand(len_data, num_features)
def train_generator(train_idxs):
    while True:
        i = train_idxs.get(block=False)
        training_example = data[i,:]
        training_example.shape = (1, len(training_example))
        yield (training_example, training_example)
layer0_size = num_features
layer1_size = layer0_size / 2
layer2_size = layer1_size / 2
layers = []
layers.append(
    Dense(input_dim=layer0_size, output_dim=layer1_size, activation='relu'))
layers.append(
    Dense(input_dim=layer1_size, output_dim=layer2_size, activation='relu'))
layers.append(
    Dense(input_dim=layer2_size, output_dim=layer1_size, activation='relu'))
layers.append(
    Dense(input_dim=layer1_size, output_dim=layer0_size, activation='sigmoid'))
model = Sequential()
for layer in layers:
    model.add(layer)
model.compile(optimizer='adam', loss='binary_crossentropy')
train_idxs = Queue()
train_idxs.queue = deque(range(len_data))
train_gen = train_generator(train_idxs)
max_q_size = 2
model.fit_generator(train_gen, samples_per_epoch=len(data), max_q_size=max_q_size, nb_epoch=1)
Keras then trains successfully on 98 of the 100 training examples and throws this error:
98/100 [============================>.] - ETA: 0s - loss: 0.6930Exception in thread Thread-1:
Traceback (most recent call last):
  File "/usr/lib/python3.5/threading.py", line 914, in _bootstrap_inner
    self.run()
  File "/usr/lib/python3.5/threading.py", line 862, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 429, in data_generator_task
    generator_output = next(self._generator)
  File "scrap.py", line 12, in train_generator
    i = train_idxs.get(block=False)
  File "/usr/lib/python3.5/queue.py", line 161, in get
    raise Empty
queue.Empty
Traceback (most recent call last):
  File "scrap.py", line 43, in <module>
    model.fit_generator(train_gen, samples_per_epoch=len(data), max_q_size=max_q_size, nb_epoch=1)
  File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 935, in fit_generator
    initial_epoch=initial_epoch)
  File "/usr/local/lib/python3.5/dist-packages/keras/engine/training.py", line 1528, in fit_generator
    str(generator_output))
ValueError: output of generator should be a tuple (x, y, sample_weight) or (x, y). Found: None
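What seems to be happening is that the generator exhausts all of the train_idxs, and Keras keeps trying to fetch more before it has used up the training examples already sitting in its internal queue. Is there a way to make it stop trying to get more training examples from the generator?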
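To make the suspicion concrete, here is a minimal sketch that uses only the standard library and no Keras. It assumes the prefetching thread behaves like a consumer that asks for one more item than the epoch needs. The names draining_generator and cycling_generator are hypothetical and do not appear in the script above. The first mirrors the generator in the question. The second is one possible alternative, based on the Keras documentation's statement that the generator passed to fit_generator is expected to loop over its data indefinitely.

```python
import itertools
from collections import deque
from queue import Queue, Empty

len_data = 100  # same size as the data set in the question


def draining_generator(idxs):
    """Mirrors the question's generator: each index is handed out only once."""
    while True:
        # Raises queue.Empty as soon as the queue has been drained.
        yield idxs.get(block=False)


def cycling_generator(n):
    """Hypothetical alternative: loops over the indices forever."""
    for i in itertools.cycle(range(n)):
        yield i


idxs = Queue()
idxs.queue = deque(range(len_data))  # same pre-filling trick as in the question
gen = draining_generator(idxs)

# Consume exactly one epoch's worth of indices.
consumed = [next(gen) for _ in range(len_data)]

# A prefetching consumer then asks for one more item than the epoch needs.
try:
    next(gen)
    ran_dry = False
except Empty:
    ran_dry = True

# The cycling version can serve extra requests without running dry.
cyc = cycling_generator(len_data)
extra = [next(cyc) for _ in range(len_data + 2)]
```

In this sketch the draining generator fails on request number 101, which matches the queue.Empty in the first traceback, while the cycling generator simply wraps around to index 0. Whether wrapping around is acceptable depends on whether seeing an example again in a later batch matters for the training run.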