加载使用Lambda层的Keras / Tensorflow模型时出现细分违规(SIGSEGV)

时间:2019-08-21 17:56:49

标签: tensorflow keras

我创建这样的模型并保存模型:

def extract_and_duplicate(tensor, reps=1, batch_size=0, sample_size=0):
    tensor = K.reshape(tensor[:,:,0],(batch_size, sample_size, 1))
    if reps > 1:
        tensor = Concatenate()([tensor for i in range(reps)])
    return tensor

input = Input(batch_shape = (batch_size, sample_size, num_features))
out = <steps to create a NN with several layers>

pre_mask = Lambda(extract_and_duplicate, arguments = {'reps': some_number,'batch_size': batch_size, 'sample_size': sample_size})(input)
mask = TimeDistributed(Dense(m, activation = 'tanh'))(out)
out = Multiply()([pre_mask,mask])
model = Model(input, out)

当我在下面的行中加载模型时,会收到SIGSEGV信号。

load_model(model_path, custom_objects={'estimated_accuracy': estimated_accuracy, 'extract_and_duplicate': extract_and_duplicate})

我单步执行load_model方法,发现在加载Lambda层时,SIGSEGV失败。 当我删除Lambda层时,加载模型即可。 我是在做错什么还是在踩keras / tensorflow错误?

您可以提供解决方案或调查步骤吗? 谢谢

1 个答案:

答案 0 :(得分:1)

您需要将创建lambda层所需的所有内容传递给custom_objects。否则Keras将不知道这些变量名的含义。

因此,添加some_numberbatch_sizesample_sizereps以及将lambda层完全构建到custom_objects中所需的所有其他内容。