keras不考虑batch_input参数

时间:2016-09-14 14:31:06

标签: keras

我正在使用keras训练神经网络,似乎没有正确解释batch_size参数。

请参阅下面的代码(应用程序很愚蠢,我关心的是输出)。

import numpy as np 
from keras.models import Sequential
from keras.layers import Activation, Dense, Reshape
import keras 

class LossHistory(keras.callbacks.Callback):
    def on_train_begin(self, logs={}):
        self.losses = []

    def on_batch_end(self, batch, logs={}):
        self.losses.append(logs.get('loss'))

history = LossHistory()


X = np.random.normal(0, 1, (1000, 2))
Y = np.random.normal(0, 1, (1000, 3))

model = Sequential()
model.add(Dense(20, input_shape = (2,), name='input layer dude'))
model.add(Activation('relu'))
model.add(Dense(12))
model.add(Activation('relu'))
model.add(Dense(8))
model.add(Activation('linear'))
model.add(Dense(3))
model.add(Activation('linear'))
model.add(Reshape(target_shape=(3,), name='output layer dude'))
model.compile(optimizer='adam', loss='mse', )

当我通过以下方式调用此模型时:

model.fit(X, Y, batch_size=10, nb_epoch=10, callbacks=[history])

输出似乎表明它不是每批10个项目,而是1000个(这是总样本数)。

Epoch 1/10
1000/1000 [==============================] - 0s - loss: 898.6197      
Epoch 2/10
1000/1000 [==============================] - 0s - loss: 31.5123     
Epoch 3/10
1000/1000 [==============================] - 0s - loss: 16.7140     
Epoch 4/10
1000/1000 [==============================] - 0s - loss: 11.4034     
Epoch 5/10
1000/1000 [==============================] - 0s - loss: 8.9275     
Epoch 6/10
1000/1000 [==============================] - 0s - loss: 7.4699     
Epoch 7/10
1000/1000 [==============================] - 0s - loss: 6.5648     
Epoch 8/10
1000/1000 [==============================] - 0s - loss: 5.9576     
Epoch 9/10
1000/1000 [==============================] - 0s - loss: 5.5064     
Epoch 10/10
1000/1000 [==============================] - 0s - loss: 5.1514     

有什么问题吗?

1 个答案:

答案 0 :(得分:0)

他实际上在考虑它。一个纪元是整个数据集的迭代,因此是1000/1000。

我将批量大小更改为128更具可读性并添加了回调以在每批次之后打印丢失,我得到的是这个(我还增加了数据量以获得更好的可读性):

class MBLossPrint(Callback):
    def on_batch_end(self, batch, logs={}):
        print ' mbloss', logs['loss'], 'lr', self.model.optimizer.lr.get_value()

如果您需要它,在批处理结束时打印一些东西的回调:

{{1}}

希望这会有所帮助:)