我最近开始使用Keras
,在他们的文档中,只显示了几行代码
inp = Input(shape=(2,))
hl_x = Dense(4, activation='tanh', name= 'First_Hidden_Layer_Following_Input' )(inp)
其中
type(Input)
>> function
type(inp)
>>>tensorflow.python.framework.ops.Tensor
Input
是一个函数,inp
是类型为tensor
的变量
这是什么意思,它是如何工作的?
答案 0 :(得分:2)
我不是Keras的专家,但我会尝试突出显示它。
首先,Keras中的所有图层都是可调用对象,例如他们定义了__call__方法。这是什么意思?这样的类可以用作函数:
x = np.random.randint(0, 10, (10,10))
functor = Layer()
res = functor(x)
这本身不是Keras的功能,只是一般的Python语法。由于与函数调用相比,对象的生命周期可能更长,因此您可能会在对象内部积累一些中间数据,例如一层可以保留所有相关的渐变。
第二,我猜是,但是我不确定,这种方法可以解决性能问题。当您定义模型时,不会发生太多事情。实际上,您只需将各层之间的输入和输出链接到一个有向图/网络中...就这样,仅此而已,就计算资源而言这非常便宜,您只需通过传递 inp来定义模型的结构和 h1_x 在各层之间,每一层只是将其注册为自己的输入/输出。所有的魔术和沉重的工作都会在以后发生-在model.compile()和实际训练/推断阶段。
答案 1 :(得分:1)
Dense(....)
返回对象that can be __called__()
,类似于参数化函数:
def print_multiple(k):
"""Returns a function that prints 'k' times whatever you give it."""
return lambda x: print(*(x for _ in range(k)))
print_multiple(6)("Merry")
print_multiple(4)("Christmas")
打印
Merry Merry Merry Merry Merry Merry
Christmas Christmas Christmas Christmas
keras.layers.dense是一个可调用对象-大致如下:
class PrintMult:
"""Object that prints 'how_often' times whatever you give it."""
def __init__(self, how_often):
self.how_often = how_often
def __call__(self, what_ever):
print(*(what_ever for _ in range(self.how_often)))
PrintMult(5)("Yeeha") # Yeeha Yeeha Yeeha Yeeha Yeeha