为fit_generator()构建一个生成器,该生成器随着时代的增长而增加批处理大小

时间:2018-10-24 19:09:40

标签: python-3.x keras generator

关于以下论文:Don't Decay the Learning Rate, Increase the Batch Size

TL; DR 如何设置生成器,该生成器会随着时期的增加而增加批量大小?

(仅在您想帮助编辑代码时进一步阅读)


目标:让训练数据集(用于回归)(x_train,y_train)为Keras中内置的ANN实现批处理生成器。 代码的主要思想是:

  1. 哪个时代?根据答案,设置批次大小
  2. 直到数据集的末尾,生成具有给定批处理大小的批处理(最后一个批处理除外,后者可能较小)

我假设model.fit_generator(data_gen(x_train, y_train))到达最后一批时,它将进入新的训练时代。

任何人都可以提出一个想法来检查批生成器的实现是否正确吗?这是我的代码,最后给出一个错误,我真的不太了解如何使其正常工作。

代码

def data_gen(features, targets):
    global epoch
    batches_produced_in_current_epoch = 1
    epoch += 1

    while True:
        print('==============================================================================')
        total_amount_of_samples = features.shape[0]

        if epoch <= 2:
            batch_size = 2 ** 4
        elif epoch <= 3:
            batch_size = 2 ** 5
        elif epoch <= 2000:    
            batch_size = 2 ** 6 

        how_many_batches_in_this_epoch_total = features.shape[0] / batch_size

        if not how_many_batches_in_this_epoch_total.is_integer():
            how_many_batches_in_this_epoch_total = int(how_many_batches_in_this_epoch_total) + 1
        else:
            how_many_batches_in_this_epoch_total = int(how_many_batches_in_this_epoch_total)

        print('Batch size:\t'+str(batch_size))
        print('Batches produced in the current epoch:\t'+str(batches_produced_in_current_epoch))
        print('How many batches in the epoch:\t{}'.format(how_many_batches_in_this_epoch_total))


        if int(batches_produced_in_current_epoch) == int(how_many_batches_in_this_epoch_total - 1):
            print('hi')
            batch_x = np.zeros((int(features.shape[0] % batch_size), features.shape[1]))
            batch_y = np.zeros((int(targets.shape[0] % batch_size), targets.shape[1]))
        else:
            batch_x = np.zeros((batch_size, features.shape[1]))
            batch_y = np.zeros((batch_size, targets.shape[1]))

        print('batch sizes:\t{}, {}'.format(batch_x.shape, batch_y.shape))

        if int(batches_produced_in_current_epoch) == int(how_many_batches_in_this_epoch_total - 1):
            batch_x[:,:] = x_train[batches_produced_in_current_epoch*batch_size:,:]
            batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:,:]
        else:
            batch_x[:,:] = x_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]
            batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]

        print('Shapes:\t{}'.format(batch_x.shape, batch_y.shape))
        print('Batch size:\t'+str(batch_size))
        print('Batches produced in the current epoch:\t'+str(batches_produced_in_current_epoch))
        print('How many batches in the epoch:\t{}'.format(how_many_batches_in_this_epoch_total))

        batches_produced_in_current_epoch += 1

        if batches_produced_in_current_epoch == how_many_batches_in_this_epoch_total:
            epoch += 1    



        yield batch_x, batch_y

然后,针对:

代码

import numpy as np 

x_train = np.random.randn(20,6)
y_train = np.random.randn(20,1)
epoch = 0
for x, y in data_gen(x_train, y_train):
    print(x.shape, y.shape)

我得到输出:

输出和错误消息

==============================================================================
Batch size: 16
Batches produced in the current epoch:  1
How many batches in the epoch:  2
hi
batch sizes:    (4, 6), (4, 1)
Shapes: (4, 6)
Batch size: 16
Batches produced in the current epoch:  1
How many batches in the epoch:  2
(4, 6) (4, 1)
==============================================================================
Batch size: 16
Batches produced in the current epoch:  2
How many batches in the epoch:  2
batch sizes:    (16, 6), (16, 1)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-145-2883bc9d95ea> in <module>()
----> 1 for x, y in data_gen(x_train, y_train):
      2     print(x.shape, y.shape)

<ipython-input-144-191d762315fa> in data_gen(features, targets)
     57             batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:,:]
     58         else:
---> 59             batch_x[:,:] = x_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]
     60             batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]
     61 

ValueError: could not broadcast input array from shape (0,6) into shape (16,6)

谢谢。

0 个答案:

没有答案