Tensoflow2 LSTM-未使用的参数input_shape吗?

时间:2020-05-02 20:40:56

标签: python tensorflow keras deep-learning lstm

所以我用以下代码构建了神经网络:

import tensorflow as tf

tf_model = tf.keras.Sequential()
tf_model.add(tf.keras.layers.LSTM(50, activation='relu'))
tf_model.add(tf.keras.layers.Dense(20, activation='relu'))
tf_model.add(tf.keras.layers.Dense(10, activation='relu'))
tf_model.add(tf.keras.layers.Dense(1, activation='linear'))
tf_model.compile(optimizer='Adam', loss='mse')

我的训练集的形状如下:

>> ts_train_X.shape
(16469, 3, 21)

我在这里阅读了许多有关stackoverflow的文章和问题,以便为LSTM提供合适的数据框。我发现几乎每个页面都指定了input_shape参数,并将其传递给LSTM(..)或Sequential(..)。

当我查看LSTM API时,找不到该参数的引用。我也对源代码有所了解,在我看来,形状是可以自动推断的,但是对此我不确定。

这使我想到了一个问题:为什么我的代码有效?如果不指定input_shape参数,那么作为第一层的LSTM层如何知道输入的形状?


编辑:根据评论中的建议更改标题。

1 个答案:

答案 0 :(得分:2)

参数input_shape可以提供给任何keras Layer子类的构造函数,因为这是API的定义方式。

代码之所以有效,是因为input_shape作为关键字参数(**kwargs)传递,然后这些关键字参数由LSTM构造函数传递给Layer构造函数,然后,它继续存储信息以供以后使用。这实际上意味着不必在每一层中都定义input_shape参数,而是将其作为关键字参数传递。

我认为问题在于,由于keras已移至tensorflow,因此文档可能不完整。您可以在Guide to the Sequential API中找到有关input_shape参数的更多信息。