我是Java开发人员,对Python和Keras还是陌生的。我有一个使用此代码的工作示例:
encoder_lstm = LSTM(self.latent_dim, return_state=True)
encoder_outputs, state_h, state_c = encoder_lstm(encoder_inputs)
因此,我知道我正在创建类LSTM
的对象的第一行,但是随后我向该对象传递了ndarray
(或类似的对象),但没有指定任何要调用的函数。
我怎么知道这里正在调用什么函数?我猜想我需要看看def call
,但是call
是Python或Keras的“默认”功能吗?
答案 0 :(得分:2)
在Python中创建类时,可以为该类定义__call__
方法。这样,该类的实例化对象在调用时就充当了函数:
class MyClass():
# ...
def __call__(self, *args, **kwargs):
print("The object was called!")
>>> obj = MyClass()
>>> obj()
"The object was called!"
现在,如果您查看Keras的源代码,您会发现Keras中所有层都继承自该类的基础层类(即Layer
)具有{{3} }方法:
def __call__(self, inputs, **kwargs):
"""Wrapper around self.call(), for handling internal references.
此方法对输入进行一些检查并更新内部引用,然后调用该层的call
方法。这就是为什么在__call__
时,您只需要重写其针对Keras的call
方法(而不是__call__
)。
现在,当您在Keras中创建这样的图层时:
encoder_lstm = LSTM(self.latent_dim, return_state=True)
,然后像这样在输入张量(不是numpy数组)上调用它:
encoder_outputs, state_h, state_c = encoder_lstm(encoder_inputs)
基本上首先调用基础层的__call__
方法,该方法内部调用相应层的call
方法,在此示例中,该方法为LSTM
层。 call
方法是该层的所有逻辑(即计算逻辑)所在的位置。
答案 1 :(得分:1)
Python中的某些对象是“可调用的”。
确实有一个为可调用对象实现的标准方法,但是它不是在Keras代码中看到的call
。这是一种__call__
方法。 (对于此类标准方法,Python在前后使用两个带有下划线的表示法,例如__init__
,它是构造函数方法)
在Keras中,您只会在 base_layer 中找到__call__
方法:https://github.com/keras-team/keras/blob/master/keras/engine/base_layer.py/#L382
在内部,此方法最终将在派生层中调用call
(不是Python标准,但对于所有Keras层都是必需的)方法。因此,如果您正在研究Keras的LSTM代码,或者正在创建自定义层,那么只需查看call
方法即可。它包含了解网络数学所必需的张量运算,而没有所有开销。