使用Keras LSTM预测使用批量培训后的单个示例

时间:2017-03-22 19:48:08

标签: neural-network keras lstm recurrent-neural-network

我有一个使用批量培训培训的网络模型。一旦训练完成,我想预测一个例子的输出。

这是我的型号代码:

model = Sequential()
model.add(Dense(32, batch_input_shape=(5, 1, 1)))
model.add(LSTM(16, stateful=True))
model.add(Dense(1, activation='linear'))
model.compile(loss='mean_squared_error', optimizer='adam', metrics=['accuracy'])

我有一系列单输入到单输出。我正在做一些测试代码来将字符映射到下一个字符(A-> B,B-> C等)。

我创建了一个形状(15,1,1)的输入数据和一个形状(15,1)的输出数据并调用函数:

model.fit(x, y, nb_epoch=epochs, batch_size=5, shuffle=False, verbose=0)

模型训练,现在我想拍摄一个角色并预测下一个角色(输入A,它预测B)。我创建了一个形状(1,1,1)的输入并调用:

pred = model.predict(x, batch_size=1, verbose=0)

这给出了:

ValueError: Shape mismatch: x has 5 rows but z has 1 rows

我看到一个解决方案是将“虚拟数据”添加到预测值,因此预测的输入形状为(5,1,1),数据为[x 0 0 0 0],您只需要输出的第一个元素作为您的值。但是,在处理较大批次时,这似乎效率低下。

我还尝试从模型创建中删除批量大小,但我收到以下消息:

ValueError: If a RNN is stateful, a complete input_shape must be provided (including batch size).

还有其他方法吗?谢谢你的帮助。

2 个答案:

答案 0 :(得分:0)

目前(Keras v2.0.8)在批量训练后,需要花费更多精力才能对单行进行预测。

基本上,batch_size在训练时固定,并且在预测时必须相同。

现在的解决方法是从训练模型中获取权重,并将其用作刚刚创建的新模型中的权重,其中batch_size为1。

快速代码是

model = create_model(batch_size=64)
mode.fit(X, y)
weights = model.get_weights()
single_item_model = create_model(batch_size=1)
single_item_model.set_weights(weights)
single_item_model.compile(compile_params)

这是一篇更深入的博客文章: https://machinelearningmastery.com/use-different-batch-sizes-training-predicting-python-keras/

我过去曾使用过这种方法在预测时有多个模型 - 一个是对大批量进行预测,一个是对小批量进行预测,另一个是对单个项目进行预测。由于批处理预测效率更高,因此我们可以灵活地接收任意数量的预测行(不仅仅是一个可被batch_size整除的数字),同时仍可快速获得预测。

答案 1 :(得分:0)

@ClimbsRocks显示了一个不错的解决方法。我无法提供“这是Keras打算完成的方式”的“正确”答案,但是我可以分享另一个替代方法,这可能会因用例而有所帮助。

在此替代方法中,我使用predict_on_batch()。此方法允许从批中传递单个样本而不会引发错误。不幸的是,它会根据训练设置返回目标形状的向量。但是,目标中的每个样本都会产生单个样本的预测。

您可以这样访问它:

to_predict = #Some single sample that would be part of a batch (has to have the right shape)#
model.predict_on_batch(to_predict)[0].flatten() #Flatten is optional

预测结果与将整个批次传递给predict()完全相同。


下面是一些鳕鱼的例子。 该代码来自my question,该代码也处理了此问题(但方式有所不同)。

sequence_size      = 5
number_of_features = 1
input              = (sequence_size, number_of_features)
batch_size         = 2

model = Sequential()

#Of course you can replace the Gated Recurrent Unit with a LSTM-layer
model.add(GRU(100, return_sequences=True, activation='relu', input_shape=input, batch_size=2, name="GRU"))
model.add(GRU(1, return_sequences=True, activation='relu', input_shape=input, batch_size=batch_size, name="GRU2"))
model.compile(optimizer='adam', loss='mse')

model.summary()

#Summary-output:
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
GRU (GRU)                    (2, 5, 100)               30600     
_________________________________________________________________
GRU2 (GRU)                   (2, 5, 1)                 306       
=================================================================
Total params: 30,906
Trainable params: 30,906
Non-trainable params: 0


def generator(data, batch_size, sequence_size, num_features):
    """Simple generator"""
    while True:
        for i in range(len(data) - (sequence_size * batch_size + sequence_size) + 1):
            start = i
            end   = i + (sequence_size * batch_size)

            yield data[start : end].reshape(batch_size, sequence_size, num_features), \
                    data[end - ((sequence_size * batch_size) - sequence_size) : end + sequence_size].reshape(batch_size, sequence_size, num_features)

#Task: Predict the continuation of a linear range
data = np.arange(100)
hist = model.fit_generator(
                generator=generator(data, batch_size, sequence_size, num_features, False),
                steps_per_epoch=total_batches,
                epochs=200,
                shuffle=False
            )

to_predict = np.asarray([[np.asarray([x]) for x in range(95,100,1)]]) #Only single element of a batch
correct    = np.asarray([100,101,102,103,104])
print( model.predict_on_batch(to_predict)[0].flatten() )

#Output:
[ 99.92908 100.95854 102.32129 103.28584 104.20213 ]