Keras set_weights在循环中变慢

时间:2018-11-07 18:30:29

标签: python tensorflow keras lstm

我正在Keras上训练LSTM模型,但想在训练过程中使用更大的批处理大小来加快GPU的处理速度。但是,我想执行一步预测(batch_size = 1),Keras自然不支持使用与训练中使用的批次大小不同的模型预测。但是如here所述,您可以创建一个指定批处理大小为1的新模型,然后将旧的权重复制到新模型中,然后进行预测。

但是,我要尽早停止并在每个时期之后通过一步预测来检查验证错误,这需要创建一个新模型并在每个时期之后复制权重。我注意到在循环中进行几次迭代后,花在复制权重上的时间增加了,但我不知道为什么。

我也只使用CPU尝试过相同的代码,但仍然遇到相同的问题。

以下是压缩代码:

# Create original model
model = Sequential()
model.add(LSTM(..., batch_input_shape=(32, x_train.shape[1], x_train.shape[2])))
model.compile(...)
# Count number of epochs before validation error increases after given patience
patience = 2
num_epochs = 0
min_error = 1000000
num_patience = 0
while True:
    print "Creating model"
    model.fit(..., epochs=1, batch_size=32)
    model.reset_states()
    print "Created model, creating new model"
    # Create new model with batch size of 1
    new_model = Sequential()
    new_model.add(LSTM(..., batch_input_shape=(1, x_train.shape[1], x_train.shape[2])))
    print "Copying weights"
    # Copy weights
    old_weights = model.get_weights()
    new_model.set_weights(old_weights)
    new_model.compile(...)
    print "Updating states"
    # Update states
    new_model.predict(x_train, batch_size=1)
    # Get predictions
    yhats = new_model.predict(x_test, batch_size=1)
    mse = mean_squared_error(ytrues, yhats)
    print mse
    # Check if we want to continue training
    if mse < min_error and abs(mse - min_error) > 0.0001:
        num_epochs = num_epochs + 1
        min_error = mse
        num_patience = 0
    else:
        if num_patience < patience:
            num_epochs = num_epochs + 1
            num_patience = num_patience + 1
        else:
            break

并查看打印语句

Creating model
Created model, creating new model
Copying weights <- starts out very fast
Updating states
0.17 <- mse
.
.
.
Creating model
Created model, creating new model
Copying weights <- slows down a lot
Updating states
0.09

可能是什么原因造成的?

0 个答案:

没有答案