Regarding the paper: Don't Decay the Learning Rate, Increase the Batch Size
TL;DR: How do I set up a generator that increases the batch size as the epochs progress?
(Read on only if you would like to help with the code.)
Goal: implement a batch generator for an ANN built in Keras, trained on a regression dataset (x_train, y_train). The main idea of the code is:
I assume the generator is passed as model.fit_generator(data_gen(x_train, y_train)).
When the last batch of an epoch is reached, the generator moves on to a new training epoch.
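For context, this is roughly how I intend to plug the generator into Keras, using the data_gen defined below. The model here is only a placeholder for illustration (my actual network is not shown), and steps_per_epoch is a guess based on the first epoch's batch size:

import numpy as np
from keras.models import Sequential
from keras.layers import Dense

# Toy regression data with the same shapes as in the test snippet further down.
x_train = np.random.randn(20, 6)
y_train = np.random.randn(20, 1)
epoch = 0  # the generator reads and updates this global

# Placeholder model, just so fit_generator has something to train.
model = Sequential([
    Dense(32, activation='relu', input_shape=(6,)),
    Dense(1),
])
model.compile(optimizer='adam', loss='mse')

# Intended call: the generator decides the batch size itself, so
# steps_per_epoch should match the number of batches it yields per epoch
# (ceil(20 / 16) = 2 while the batch size is 16).
model.fit_generator(data_gen(x_train, y_train), steps_per_epoch=2, epochs=5)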
Can anyone suggest a way to check whether the batch generator is implemented correctly? Here is my code; it throws an error at the end, and I don't really understand how to make it work.
Code
def data_gen(features, targets):
    global epoch
    batches_produced_in_current_epoch = 1
    epoch += 1
    while True:
        print('==============================================================================')
        total_amount_of_samples = features.shape[0]
        # Batch size schedule: grows as the epoch counter increases.
        if epoch <= 2:
            batch_size = 2 ** 4
        elif epoch <= 3:
            batch_size = 2 ** 5
        elif epoch <= 2000:
            batch_size = 2 ** 6
        # Number of batches needed to cover the dataset with this batch size.
        how_many_batches_in_this_epoch_total = features.shape[0] / batch_size
        if not how_many_batches_in_this_epoch_total.is_integer():
            how_many_batches_in_this_epoch_total = int(how_many_batches_in_this_epoch_total) + 1
        else:
            how_many_batches_in_this_epoch_total = int(how_many_batches_in_this_epoch_total)
        print('Batch size:\t' + str(batch_size))
        print('Batches produced in the current epoch:\t' + str(batches_produced_in_current_epoch))
        print('How many batches in the epoch:\t{}'.format(how_many_batches_in_this_epoch_total))
        if int(batches_produced_in_current_epoch) == int(how_many_batches_in_this_epoch_total - 1):
            # Remainder batch: allocate only the leftover rows.
            print('hi')
            batch_x = np.zeros((int(features.shape[0] % batch_size), features.shape[1]))
            batch_y = np.zeros((int(targets.shape[0] % batch_size), targets.shape[1]))
        else:
            # Full-sized batch.
            batch_x = np.zeros((batch_size, features.shape[1]))
            batch_y = np.zeros((batch_size, targets.shape[1]))
        print('batch sizes:\t{}, {}'.format(batch_x.shape, batch_y.shape))
        if int(batches_produced_in_current_epoch) == int(how_many_batches_in_this_epoch_total - 1):
            batch_x[:,:] = x_train[batches_produced_in_current_epoch*batch_size:,:]
            batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:,:]
        else:
            batch_x[:,:] = x_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]
            batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]
        print('Shapes:\t{}'.format(batch_x.shape, batch_y.shape))
        print('Batch size:\t' + str(batch_size))
        print('Batches produced in the current epoch:\t' + str(batches_produced_in_current_epoch))
        print('How many batches in the epoch:\t{}'.format(how_many_batches_in_this_epoch_total))
        batches_produced_in_current_epoch += 1
        if batches_produced_in_current_epoch == how_many_batches_in_this_epoch_total:
            epoch += 1
        yield batch_x, batch_y
Then, running it with the following test:
Code
import numpy as np
x_train = np.random.randn(20,6)
y_train = np.random.randn(20,1)
epoch = 0
for x, y in data_gen(x_train, y_train):
    print(x.shape, y.shape)
I get this output:
Output and error message
==============================================================================
Batch size: 16
Batches produced in the current epoch: 1
How many batches in the epoch: 2
hi
batch sizes: (4, 6), (4, 1)
Shapes: (4, 6)
Batch size: 16
Batches produced in the current epoch: 1
How many batches in the epoch: 2
(4, 6) (4, 1)
==============================================================================
Batch size: 16
Batches produced in the current epoch: 2
How many batches in the epoch: 2
batch sizes: (16, 6), (16, 1)
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-145-2883bc9d95ea> in <module>()
----> 1 for x, y in data_gen(x_train, y_train):
2 print(x.shape, y.shape)
<ipython-input-144-191d762315fa> in data_gen(features, targets)
57 batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:,:]
58 else:
---> 59 batch_x[:,:] = x_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]
60 batch_y[:,:] = y_train[batches_produced_in_current_epoch*batch_size:(batches_produced_in_current_epoch + 1)*batch_size,:]
61
ValueError: could not broadcast input array from shape (0,6) into shape (16,6)
Thanks.