具有tensorflow的Keras抛出ResourceExhaustedError

时间:2017-11-07 15:50:56

标签: tensorflow out-of-memory keras

出于研究目的,我正在训练一个神经网络,它根据时代的奇偶性不同地更新其权重:

1)如果纪元是偶数,则用反向传播改变NN的权重

2)如果纪元是奇数,则仅用update_weights_with_custom_function()更新模型,因此冻结网络。

以下是实现此功能的代码的简化部分(请注意epochs=1):

for epoch in range(nb_epoch):

    if epoch % 2 == 0:
        model.trainable = True    # Unfreeze the model
    else:
        model.trainable = False   # Freeze the model
    model.compile(optimizer=optim, loss=gaussian_loss, metrics=['accuracy'])

    hist = model.fit(X_train, Y_train, 
             batch_size=batch_size,
             epochs=1,
             shuffle=True,
             verbose=1,
             callbacks=[tbCallBack, csv_epochs, early_stop],
             validation_data=(X_val, Y_val))

    if epoch % 2 == 1:
        update_weights_with_custom_function()

问题:在几个时代之后,keras会抛出一个ResourceExhaustedError,但只有张量流,没有theano 。似乎循环compile()正在创建模型而不释放它们。

因此,我该怎么办?我知道K.clear_session()释放了内存,但它需要保存模型并重新加载它(see),这给我一些问题,因为load_model()在我的情况下不能开箱即用。< / p>

我也可以采取其他方式来实现我想要实现的目标(即根据时代的平价来冻结NN模型)。

摘要:带有张量流后端的keras正在抛出ResourceExhaustedError,因为我正在循环compile()

1 个答案:

答案 0 :(得分:0)

正如MarcinMożejko指出的那样,使用#include <openssl/rsa.h> #include <openssl/pem.h> #include <openssl/md5.h> #include <string.h> #include <stdio.h> using namespace std; int main(int argc, char* argv[]) { FILE* f; RSA* pRSAPRI = RSA_new(); f = fopen(argv[1], "r"); RSA *private_key = PEM_read_RSAPrivateKey(f, &pRSAPRI, NULL, NULL); unsigned char sourceText[100]; unsigned char cipher[2048]; strcpy((char*)sourceText, "some_val"); int ret = RSA_private_encrypt(25, sourceText, cipher, private_key, RSA_PKCS1_PADDING); if (ret < 0) { printf("RSA_private_encrypt failed\n"); exit (-1); } std::cout << " cipher = " << cipher << std::endl; unsigned char md5Result[MD5_DIGEST_LENGTH]; MD5((unsigned char*)&cipher, 2048, (unsigned char*)&md5Result); char mdString[33]; for(int i = 0; i < 16; i++) sprintf(&mdString[i*2], "%02x", (unsigned int)md5Result[i]); printf(" md5 : %s\n", mdString); return 0; } 正在完成我想要实现的目标。

我添加了一个自定义回调(灵感为here),这避免了eval()

的循环

即使没有直接解决张量流问题,问题现在也解决了。