我有下一个代码:
from sklearn.model_selection import train_test_split
from scipy.misc import imresize
def _chunks(l, n):
"""Yield successive n-sized chunks from l."""
for i in range(0, len(l), n):
yield l[i:i + n]
def _batch_generator(data, batch_size):
indexes = range(len(data))
index_chunks = _chunks(indexes, batch_size)
for i, indexes in enumerate(index_chunks):
print("\nLoaded batch {0}\n".format(i + 1))
batch_X = []
batch_y = []
for index in indexes:
record = data[index]
image = _read_train_image(record["id"], record["index"])
mask = _read_train_mask(record["id"], record["index"])
mask_resized = imresize(mask, (1276, 1916)) >= 123
mask_reshaped = mask_resized.reshape((1276, 1916, 1))
batch_X.append(image)
batch_y.append(mask_reshaped)
np_batch_X = np.array(batch_X)
np_batch_y = np.array(batch_y)
yield np_batch_X, np_batch_y
def train(data, model, batch_size, epochs):
train_data, test_data = train_test_split(data)
samples_per_epoch = len(train_data)
steps_per_epoch = samples_per_epoch // batch_size
print("Train on {0} records ({1} batches)".format(samples_per_epoch, steps_per_epoch))
train_generator = _batch_generator(train_data, batch_size)
model.fit_generator(train_generator,
steps_per_epoch=steps_per_epoch,
nb_epoch=epochs,
verbose=1)
train(train_indexes[:30], autoencoder,
batch_size=2,
epochs=1)
所以看起来它必须是下一个方式:
len(list(_batch_generator(train_indexes[:22], 2)))
确实返回11 steps_per_epoch=steps_per_epoch
)nb_epochs=epochs
,epochs=1
)但输出有下一个观点:
Train on 22 records (11 batches)
Epoch 1/1
Loaded batch 1
C:\Users\user\venv\machinelearning\lib\site-packages\ipykernel_launcher.py:39: UserWarning: The semantics of the Keras 2 argument `steps_per_epoch` is not the same as the Keras 1 argument `samples_per_epoch`. `steps_per_epoch` is the number of batches to draw from the generator at each epoch. Basically steps_per_epoch = samples_per_epoch/batch_size. Similarly `nb_val_samples`->`validation_steps` and `val_samples`->`steps` arguments have changed. Update your method calls accordingly.
C:\Users\user\venv\machinelearning\lib\site-packages\ipykernel_launcher.py:39: UserWarning: Update your `fit_generator` call to the Keras 2 API: `fit_generator(<generator..., steps_per_epoch=11, verbose=1, epochs=1)`
Loaded batch 2
1/11 [=>............................] - ETA: 11s - loss: 0.7471
Loaded batch 3
Loaded batch 4
Loaded batch 5
Loaded batch 6
2/11 [====>.........................] - ETA: 17s - loss: 0.7116
Loaded batch 7
Loaded batch 8
Loaded batch 9
Loaded batch 10
3/11 [=======>......................] - ETA: 18s - loss: 0.6931
Loaded batch 11
Exception in thread Thread-50:
Traceback (most recent call last):
File "C:\Anaconda3\Lib\threading.py", line 916, in _bootstrap_inner
self.run()
File "C:\Anaconda3\Lib\threading.py", line 864, in run
self._target(*self._args, **self._kwargs)
File "C:\Users\user\venv\machinelearning\lib\site-packages\keras\utils\data_utils.py", line 560, in data_generator_task
generator_output = next(self._generator)
StopIteration
4/11 [=========>....................] - ETA: 18s - loss: 0.6663
---------------------------------------------------------------------------
StopIteration Traceback (most recent call last)
<ipython-input-16-092ba6eb51d2> in <module>()
1 train(train_indexes[:30], autoencoder,
2 batch_size=2,
----> 3 epochs=1)
<ipython-input-15-f2fec4e53382> in train(data, model, batch_size, epochs)
37 steps_per_epoch=steps_per_epoch,
38 nb_epoch=epochs,
---> 39 verbose=1)
C:\Users\user\venv\machinelearning\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
85 warnings.warn('Update your `' + object_name +
86 '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 87 return func(*args, **kwargs)
88 wrapper._original_function = func
89 return wrapper
C:\Users\user\venv\machinelearning\lib\site-packages\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, initial_epoch)
1807 batch_index = 0
1808 while steps_done < steps_per_epoch:
-> 1809 generator_output = next(output_generator)
1810
1811 if not hasattr(generator_output, '__len__'):
StopIteration:
正如我所看到的 - 所有批次都是成功的(参见&#34;已加载的批次&#34;)
但是在处理第1期的第3批时,keras会引发StopIteration。
答案 0 :(得分:3)
我也遇到了这个问题,我发现一种方法是可以在数据生成器函数中插入“ while True”块。但我无法获得消息来源。您可以参考以下代码:
while True:
assert len(inputs) == len(targets)
indices = np.arange(len(inputs))
if shuffle:
np.random.shuffle(indices)
if batchsize > len(indices):
sys.stderr.write('BatchSize out of index size')
batchsize = len(indices)
for start_idx in range(0, len(inputs) - batchsize + 1, batchsize):
if shuffle:
excerpt = indices[start_idx:start_idx + batchsize]
else:
excerpt = slice(start_idx, start_idx + batchsize)
yield inputs[excerpt], targets[excerpt]
答案 1 :(得分:1)
我找到了问题来源。 首先 - 我的数据集在匹配结束前完全重新加载,因此它会提升
Exception in thread Thread-50:
Traceback (most recent call last):
File "C:\Anaconda3\Lib\threading.py", line 916, in _bootstrap_inner
self.run()
File "C:\Anaconda3\Lib\threading.py", line 864, in run
self._target(*self._args, **self._kwargs)
File "C:\Users\user\venv\machinelearning\lib\site-packages\keras\utils\data_utils.py", line 560, in data_generator_task
generator_output = next(self._generator)
StopIteration
异常处理程序设置stop_event并重新引用异常
但是:
def get(self):
"""Creates a generator to extract data from the queue.
Skip the data if it is `None`.
# Returns
A generator
"""
while self.is_running():
if not self.queue.empty():
inputs = self.queue.get()
if inputs is not None:
yield inputs
else:
time.sleep(self.wait_time)
因此,当设置停止事件时 - 它可以从队列中加载数据
所以我将max_queue_size限制为1。
答案 2 :(得分:1)
关于此问题的注释,以防其他人来此页面追逐。 StopIteration错误是keras中的一个已知问题,可以通过确保将批大小设置为样本数量的整数倍来解决,在某些情况下可以解决。如果这不能解决问题,我发现的一件事是,具有数据生成器无法读取的时髦文件格式有时还会导致stopIteration错误。为了解决这个问题,我在训练文件夹上运行了一个脚本,该脚本在训练之前将所有图像转换为标准文件类型(jpg或png)。看起来像这样。
import glob
from PIL import Image
import os
d=1
for sample in glob.glob(r'C:\Users\Jeremiah\Pictures\training\classLabel_unformatted\*'):
im = Image.open(sample)
im.save(r'C:\Users\Jeremiah\Pictures\training\classLabel_formatted\%s.png' %d)
d=d+1
我发现运行此脚本或类似的脚本可以大大减少我出现此类错误的频率,尤其是当我的训练数据来自诸如Google图片搜索之类的地方时。