Keras model.fit_generator raises a StopIteration error

Asked: 2018-10-30 11:59:36

Tags: python-3.x tensorflow keras

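I am trying to run the SSD port from ssd300_training.ipynb in this GitHub repository, https://github.com/pierluigiferrari/ssd_keras, on a Windows 10 system. Training the model fails with a "StopIteration" error. After a little research, the suggested solution is to add a `while True` block inside the data generator function. However, I am not proficient enough with this code to do that, because the generator here is not written as a plain function. This is the data generator code: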

# 6: Create the generator handles that will be passed to Keras' `fit_generator()` function.
train_generator = train_dataset.generate(batch_size=batch_size,
                                         shuffle=True,
                                         transformations=[ssd_data_augmentation],
                                         label_encoder=ssd_input_encoder,
                                         returns={'processed_images',
                                                  'encoded_labels'},
                                         keep_images_without_gt=False)

val_generator = val_dataset.generate(batch_size=batch_size,
                                     shuffle=False,
                                     transformations=[convert_to_3_channels,
                                                      resize],
                                     label_encoder=ssd_input_encoder,
                                     returns={'processed_images',
                                              'encoded_labels'},
                                     keep_images_without_gt=False)
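For context, the `while True` suggestion usually refers to a pattern like the one below, where the loop lives inside a hand-written generator function. This is only a generic sketch: `batch_generator` and its arguments are hypothetical and are not part of ssd_keras, whose `generate()` method hands back a ready-made generator instead.

```python
def batch_generator(images, labels, batch_size):
    """Yield (images, labels) batches forever, cycling over the data.

    Hypothetical helper for illustration only; not part of ssd_keras.
    """
    n = len(images)
    while True:  # the outer loop keeps the generator from ever finishing
        for start in range(0, n, batch_size):
            end = start + batch_size
            yield images[start:end], labels[start:end]


# Toy data: 4 samples, batch size 2, so one pass yields 2 batches.
toy_images = [[0, 1], [2, 3], [4, 5], [6, 7]]
toy_labels = [0, 1, 2, 3]
gen = batch_generator(toy_images, toy_labels, batch_size=2)

# Asking for more batches than one pass contains does not raise
# StopIteration, because the generator wraps around to the start.
batches = [next(gen) for _ in range(5)]
```

Because the loop never exits, `next()` can be called on such a generator any number of times.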

How can I add a `while True:` loop so that it stops raising the StopIteration error? This is the training code and the error it produces:

initial_epoch   = 0
final_epoch     = 50
steps_per_epoch = 10

history = model.fit_generator(generator=train_generator,
                              steps_per_epoch=steps_per_epoch,
                              epochs=final_epoch,
                              callbacks=callbacks,
                              validation_data=val_generator,
                              validation_steps=ceil(val_dataset_size/batch_size),
                              initial_epoch=initial_epoch)

Error:

Epoch 1/50
 9/10 [==========================>...] - ETA: 0s - loss: 3.4504
---------------------------------------------------------------------------
StopIteration                             Traceback (most recent call last)
<ipython-input-16-484d98ebb6b5> in <module>()
     11                               #callbacks=callbacks,
     12                               validation_data=val_generator,
---> 13                               validation_steps=ceil(val_dataset_size/batch_size),
     14                               #initial_epoch=initial_epoch
     15                              )

c:\users\keboc\anaconda3\envs\tensorflow_1.8\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

c:\users\keboc\anaconda3\envs\tensorflow_1.8\lib\site-packages\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1413             use_multiprocessing=use_multiprocessing,
   1414             shuffle=shuffle,
-> 1415             initial_epoch=initial_epoch)
   1416 
   1417     @interfaces.legacy_generator_methods_support

c:\users\keboc\anaconda3\envs\tensorflow_1.8\lib\site-packages\keras\engine\training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    228                             val_enqueuer_gen,
    229                             validation_steps,
--> 230                             workers=0)
    231                     else:
    232                         # No need for try/except because

c:\users\keboc\anaconda3\envs\tensorflow_1.8\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

c:\users\keboc\anaconda3\envs\tensorflow_1.8\lib\site-packages\keras\engine\training.py in evaluate_generator(self, generator, steps, max_queue_size, workers, use_multiprocessing, verbose)
   1467             workers=workers,
   1468             use_multiprocessing=use_multiprocessing,
-> 1469             verbose=verbose)
   1470 
   1471     @interfaces.legacy_generator_methods_support

c:\users\keboc\anaconda3\envs\tensorflow_1.8\lib\site-packages\keras\engine\training_generator.py in evaluate_generator(model, generator, steps, max_queue_size, workers, use_multiprocessing, verbose)
    325 
    326         while steps_done < steps:
--> 327             generator_output = next(output_generator)
    328             if not hasattr(generator_output, '__len__'):
    329                 raise ValueError('Output of generator should be a tuple '

c:\users\keboc\anaconda3\envs\tensorflow_1.8\lib\site-packages\keras\utils\data_utils.py in get(self)
    783                 all_finished = all([not thread.is_alive() for thread in self._threads])
    784                 if all_finished and self.queue.empty():
--> 785                     raise StopIteration()
    786                 else:
    787                     time.sleep(self.wait_time)

StopIteration: 
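The traceback above ends inside `evaluate_generator`, at `next(output_generator)`, so the error surfaces while validation batches are being pulled. The sketch below reproduces that situation in plain Python, under the assumption that the underlying generator simply ran out of batches (other causes are possible). It also shows one possible workaround: wrapping a generator factory so that it restarts when exhausted. The names `finite_batches`, `consume`, and `repeat_forever` are hypothetical helpers, not Keras or ssd_keras functions.

```python
from math import ceil


def finite_batches(num_samples, batch_size):
    """A finite generator: yields each batch of sample indices once, then ends."""
    for start in range(0, num_samples, batch_size):
        yield list(range(start, min(start + batch_size, num_samples)))


def consume(generator, steps):
    """Pull `steps` batches with next(), similar to the loop in the traceback."""
    steps_done = 0
    while steps_done < steps:
        next(generator)  # raises StopIteration if the generator is exhausted
        steps_done += 1
    return steps_done


def repeat_forever(make_generator):
    """Re-create and replay a finite generator indefinitely.

    `make_generator` is a zero-argument callable returning a fresh generator.
    """
    while True:
        produced = False
        for batch in make_generator():
            produced = True
            yield batch
        if not produced:
            # Avoid spinning forever on a generator that yields nothing.
            raise ValueError("generator produced no batches")


steps = ceil(10 / 4)  # 10 samples in batches of 4 -> 3 batches per pass

consume(finite_batches(10, 4), steps)  # exactly enough batches

try:
    consume(finite_batches(10, 4), steps + 1)  # one batch too many
except StopIteration:
    print("finite generator ran out of batches")

# The wrapped version restarts instead of stopping.
consume(repeat_forever(lambda: finite_batches(10, 4)), steps + 1)
```

If `val_dataset.generate(...)` can safely be called more than once (an untested assumption), the same idea would be `repeat_forever(lambda: val_dataset.generate(...))` with the original arguments.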

Thanks for your help.

0 answers:

No answers yet