为什么在layers.embedded在keras张量流中需要input_length?

时间:2020-05-17 08:14:34

标签: tensorflow keras reshape tensorflow2.0 tf.keras

layers.embedding具有documentation描述为的参数(input_length):

input_length:输入序列的长度,如果为常数。如果要在上游连接Flatten然后连接Dense层,则必须使用此参数(否则,将无法计算密集输出的形状)。

为什么无法计算密集输出的形状。对我来说,Flatten似乎很容易做到。它只是一个tf.rehshape(input,(-1,1)),后跟一个具有我们选择的任意输出形状的密集层。

您能帮助我指出对整体逻辑的理解上的失误吗?

1 个答案:

答案 0 :(得分:2)

通过指定尺寸,可以确保模型接收到固定长度的输入。

从技术上讲,您可以将None放在所需的任何输入尺寸处。形状将在运行时推断。

您只需要确保指定了图层参数(input_dim,output_dim),kernel_size(用于转换层),单位(用于FC层)。

如果使用Input并指定通过网络传递张量的形状,则可以计算形状。

例如以下模型是完全有效的:

from tensorflow.keras import layers
from tensorflow.keras import models

ip = layers.Input((10))
emb = layers.Embedding(10, 2)(ip)
flat = layers.Flatten()(emb)
out = layers.Dense(5)(flat)

model = models.Model(ip, out)

model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 10)]              0         
_________________________________________________________________
embedding (Embedding)        (None, 10, 2)             20        
_________________________________________________________________
flatten (Flatten)            (None, 20)                0         
_________________________________________________________________
dense (Dense)                (None, 5)                 105       
=================================================================
Total params: 125
Trainable params: 125
Non-trainable params: 0

在这里,我没有指定input_length,但是它是从Input层推断出来的。

问题在于顺序API,如果您既未在Input层中又未在嵌入层中指定输入形状,则无法使用适当的参数集来构建模型。< / strong>

例如,

from tensorflow.keras import layers
from tensorflow.keras import models

model = models.Sequential()
model.add(layers.Embedding(10, 2, input_length = 10)) # will be an error if I don't specify input_length here as there is no way to know the shape of the next layers without knowing the length

model.add(layers.Flatten())
model.add(layers.Dense(5))


model.summary()

在此示例中,必须指定input_length,否则模型将引发错误。