在Keras中使用ModelCheckpoint时最大递归深度错误

时间:2019-04-20 21:15:53

标签: keras

我正在训练具有约30,000个参数的模型,并且我想使用ModelCheckpoint回调在每个时期后保存模型的状态。

当前,我无法保存模型。我收到以下错误:

RuntimeError: maximum recursion depth exceeded while calling a Python object

我尝试了在其他地方(例如here)找到的解决方案,但无济于事。

import sys
sys.setrecursionlimit(10000)

如何解决这个问题?

1 个答案:

答案 0 :(得分:0)

啊哈!我找到了线索here ...

...并对其进行了修改。我有一个Lambda层,其最初编写如下:

    def sampling(args):
        z_mean, z_log_sigma = args
        epsilon = K.random_normal(shape=((self.batch_size - self.n_lags), self.hid_dim_2), mean=0., stddev=1.) 
        return z_mean + z_log_sigma * epsilon

    zlambda = Lambda(sampling, output_shape=(self.hid_dim_2,))([z_mean, z_log_sigma]) 

我将shape移到了全局变量,瞧!

    shape_val = ((self.batch_size - self.n_lags), self.hid_dim_2) #Need this to be a global variable, otherwise Recursion Depth Error of death.

    def sampling(args):
        z_mean, z_log_sigma = args
        epsilon = K.random_normal(shape=shape_val, mean=0., stddev=1.)
        return z_mean + z_log_sigma * epsilon

    zlambda = Lambda(sampling, output_shape=(self.hid_dim_2,))([z_mean, z_log_sigma]) #Saving and loading Lambda layers is weird.