凯拉斯| LSTM批处理生成器

时间:2019-02-21 15:50:33

标签: python keras lstm

我有一个Keras模型,该模型具有input shape = (frames, height, width, channels)并具有两个标量输出(请参阅下一个代码par)。我的模型确实使用了LSTM,这就是为什么我必须添加额外的尺寸。

height = 32
width = 64
channels = 3
frames = 2
img_shape = (height, width, channels)
input_shape = (frames, height, width, channels)

如果已按以下方式定义了批处理生成器

def generator(df, batch_size, frames_per_scene=frames_per_scene):

    ### read data frame columns 
    # inputs
    img_list = df['filename']
    # outputs
    happiness= df['happiness']
    anger = df['anger']

    # create empty arrays for input and output
    batch_img = np.zeros((batch_size, frames_per_scene) + img_shape)
    batch_label = np.zeros((batch_size, 2))

    index = 0

    while True:
        for i in range(batch_size):          
            for j in range(frames_per_scene):
                label = [happiness.iloc[index], anger.iloc[index]]
                img_name = img_list[index]

                pil_img = image.load_img(img_name)

                batch_img[i, j, :] = image.img_to_array(pil_img)
            batch_label[i] = label

            index += 1
            if index == len(img_list):
                img_list = df['filename']
                happiness = df['happiness']
                anger = df['anger']

                index = 0
        yield batch_img, batch_label

当我尝试使用以下模型方法时。

model.fit_generator(
    train_batch, train_steps, epochs=epochs, verbose=verbose, 
    callbacks=callbacks_list, validation_data=validation_batch, 
    validation_steps=val_steps)

我收到以下错误

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
<ipython-input-7-289c48f3bbf8> in <module>
----> 1 train_model(_episode=0)

<ipython-input-6-b4f2c2235a41> in train_model(_episode)
     56         TensorBoard(log_dir=path_tensorboard, histogram_freq=0, write_graph=False, write_images=False)]
     57 
---> 58     model.fit_generator(train_batch, train_steps, epochs=epochs, verbose=verbose, callbacks=callbacks_list, validation_data=validation_batch, validation_steps=val_steps)

c:\envs\lstm\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
     85                 warnings.warn('Update your `' + object_name +
     86                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 87             return func(*args, **kwargs)
     88         wrapper._original_function = func
     89         return wrapper

c:\envs\lstm\lib\site-packages\keras\engine\training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   2167                                 max_queue_size=max_queue_size,
   2168                                 workers=workers,
-> 2169                                 use_multiprocessing=use_multiprocessing)
   2170                         else:
   2171                             # No need for try/except because

c:\envs\lstm\lib\site-packages\keras\legacy\interfaces.py in wrapper(*args, **kwargs)
     85                 warnings.warn('Update your `' + object_name +
     86                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 87             return func(*args, **kwargs)
     88         wrapper._original_function = func
     89         return wrapper

c:\envs\lstm\lib\site-packages\keras\engine\training.py in evaluate_generator(self, generator, steps, max_queue_size, workers, use_multiprocessing)
   2278 
   2279             while steps_done < steps:
-> 2280                 generator_output = next(output_generator)
   2281                 if not hasattr(generator_output, '__len__'):
   2282                     raise ValueError('Output of generator should be a tuple '

c:\envs\lstm\lib\site-packages\keras\utils\data_utils.py in get(self)
    733             success, value = self.queue.get()
    734             if not success:
--> 735                 six.reraise(value.__class__, value, value.__traceback__)

c:\envs\lstm\lib\site-packages\six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None

c:\envs\lstm\lib\site-packages\keras\utils\data_utils.py in data_generator_task()
    633                 try:
    634                     if self._use_multiprocessing or self.queue.qsize() < max_queue_size:
--> 635                         generator_output = next(self._generator)
    636                         self.queue.put((True, generator_output))
    637                     else:

<ipython-input-5-b23230fc675a> in generator(df, batch_size, frames_per_scene)
     18             for j in range(frames_per_scene):
     19                 label = [steer.iloc[index], throttle.iloc[index]]
---> 20                 img_name = img_list[index]
     21 
     22                 pil_img = image.load_img(img_name)

c:\pyenvs\ca\lib\site-packages\pandas\core\series.py in __getitem__(self, key)
    765         key = com._apply_if_callable(key, self)
    766         try:
--> 767             result = self.index.get_value(self, key)
    768 
    769             if not is_scalar(result):

c:\envs\lstm\lib\site-packages\pandas\core\indexes\base.py in get_value(self, series, key)
   3116         try:
   3117             return self._engine.get_value(s, k,
-> 3118                                           tz=getattr(series.dtype, 'tz', None))
   3119         except KeyError as e1:
   3120             if len(self) > 0 and self.inferred_type in ['integer', 'boolean']:

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_value()

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_value()

pandas\_libs\index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()

pandas\_libs\hashtable_class_helper.pxi in pandas._libs.hashtable.Int64HashTable.get_item()

KeyError: 0
  

问题:有人遇到过类似的错误吗?

1 个答案:

答案 0 :(得分:1)

我认为此错误可能来自您为数据框建立索引的方式。验证您在df中的索引为0。

一种解决方案可能是将基础的numpy数组存储到img_listhappinessangerpandas.Series对象中。

这将给出:

def generator(df, batch_size, frames_per_scene=frames_per_scene):

    ### read data frame columns 
    # inputs
    img_list = df['filename'].values
    # outputs
    happiness= df['happiness'].values
    anger = df['anger'].values

    # create empty arrays for input and output
    batch_img = np.zeros((batch_size, frames_per_scene) + img_shape)
    batch_label = np.zeros((batch_size, 2))

    index = 0

    while True:
        for i in range(batch_size):          
            for j in range(frames_per_scene):
                label = [happiness[index], anger[index]]
                img_name = img_list[index]

                pil_img = image.load_img(img_name)

                batch_img[i, j, :] = image.img_to_array(pil_img)
            batch_label[i] = label

            index += 1
            if index == len(img_list):
                img_list = df['filename'].values
                happiness = df['happiness'].values
                anger = df['anger'].values

                index = 0
        yield batch_img, batch_label