使用Keras进行深度学习:如何理解Lambda图层和lambda函数?

时间:2017-07-08 10:01:23

标签: lambda keras

我的代码是这样的:

labels = Input(name='the_labels', shape=[1], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')

loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name="ctc")([output, labels, input_length, label_length])

model = Model(input=[net_input, labels, input_length, label_length],  output=[loss_out])
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred},  optimizer=optimizer, metrics=[])

我的ctc_lambda_func定义如下:

def ctc_lambda_func(args):
  y_pred, labels, input_length, label_length = args
  # the 2 is critical here since the first couple outputs of the RNN
  # tend to be garbage:
  shift = 2
  y_pred = y_pred[:, shift:, :]
  input_length -= shift
 return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

我无法理解:

  1. loss_out = Lambda(ctc_lambda_func,output_shape =(1,), name =“ctc”)([output,labels,input_length,label_length])
  2. 有人说Lambda是一个图层,但我该如何使用这个图层?

    (ctc_lambda_func, output_shape=(1,), name="ctc")是函数“Lambda”的参数 - 但参数([output, labels, input_length, label_length])是什么?

    1. loss={'ctc': lambda y_true, y_pred: y_pred}是损失函数。我发现丢失函数列表如下,但没有ctc。
      • mean_squared_error(y_true,y_pred)mean_absolute_error(y_true, y_pred)
      • mean_absolute_percentage_error(y_true,y_pred)
      • mean_squared_logarithmic_error(y_true,y_pred)
      • squared_hinge(y_true,y_pred)
      • 铰链(y_true,y_pred)
      • categorical_hinge(y_true,y_pred)
      • logcosh(y_true,y_pred)
      • categorical_crossentropy(y_true,y_pred)
      • sparse_categorical_crossentropy(y_true,y_pred)
      • binary_crossentropy(y_true,y_pred)
      • kullback_leibler_divergence(y_true,y_pred)
      • poisson(y_true,y_pred) -cosine_proximity(y_true,y_pred)
    2. 我是Keras和Python的新手。如果你能给我一些解释,我非常感激。

1 个答案:

答案 0 :(得分:2)

这不是Keras库的正确用法。该代码似乎绕过了API shortcoming

  1. Lambda图层通常用于实现自定义函数,作为Keras中计算图的一部分。 ([output, labels, input_length, label_length])是传递给自定义函数的张量,在此函数中。这个复杂解决方案背后的原因是API不仅允许(output, labels)签名用于丢失函数。
  2. 现在实现了损失函数,代码的第二部分通过返回预测张量y_pred来绕过Keras内置损失函数。 loss={'ctc':...层的'ctc'被绕过损失,即Lambda。
  3. 您可以实现自定义丢失功能,但这不是API的预期用途。对于具有多个参数的丢失解决方案,请查看此question