如何在培训期间更改批量大小?

时间:2018-05-30 14:16:02

标签: keras implementation

在培训期间,在每个时代,我想更改批量大小(用于Check if element exists in jQuery目的)。 创建自定义Callback似乎合适,但batch_size不是Model类的成员。

我看到的唯一方法是覆盖experimental并在每个循环中将batch_size公开给回调。如果不使用回调,是否有更简洁或更快捷的方法?

3 个答案:

答案 0 :(得分:3)

我认为最好使用自定义数据生成器来控制传递给训练循环的数据,这样您就可以生成不同大小的批量,动态处理数据等。这是一个大纲:< / p>

def data_gen(data):
  while True: # generator yields forever
    # process data into batch, it could be any size
    # it's your responsibility to construct a batch
    yield x,y # here x and y are a single batch

现在您可以使用model.fit_generator(data_gen(data), steps_per_epoch=100)进行训练,每个时期将产生100个批次。如果要将其封装在类中,也可以使用Sequence

答案 1 :(得分:0)

对于在此降落的其他人,我发现在Keras中进行批量大小调整的最简单方法是多次调用fit(使用不同的批量大小):

model.fit(X_train, y_train, batch_size=32, epochs=20)

# ...continue training with a larger batch size
model.fit(X_train, y_train, batch_size=512, epochs=10) 

答案 2 :(得分:0)

在大多数情况下,可接受的答案是最好的,请勿更改批次大小。这个问题可能有99%的更好的解决方法。

对于那些确实有例外情况的1%族人,在网络中部更改批量大小是适当的,有一个git讨论可以很好地解决这个问题:

https://github.com/keras-team/keras/issues/4807

总结一下:Keras不想更改批次大小,因此您需要作弊并添加一个尺寸,并告诉keras它的batch_size为1。例如,调整10张cifar10图像的大小[10, 32, 32, 3],现在变成[1, 10, 32, 32, 3]。您需要在整个网络中适当地重塑它。使用tf.expand_dimstf.squeeze来添加和删除维。