我想了解RNN,特别是LSTM如何使用Keras和Tensorflow处理多个输入维度。我的意思是输入形状是(batch_size,timesteps,input_dim)其中input_dim> 1.
我认为如果input_dim = 1,下面的图像很好地说明了LSTM的概念
这是否意味着如果input_dim> 1然后x不再是单个值而是一个数组?但如果它是这样的,那么权重也会变成数组,形状与x +上下文相同?
答案 0 :(得分:3)
Keras创建了一个计算图,可以在每个要素的底部图片中执行序列(但对于所有单位)。这意味着状态值C始终是标量,每单位一个。它不会立即处理功能,它会立即处理单元,并单独提供功能。
import keras.models as kem
import keras.layers as kel
model = kem.Sequential()
lstm = kel.LSTM(units, input_shape=(timesteps, features))
model.add(lstm)
model.summary()
free_params = (4 * features * units) + (4 * units * units) + (4 * num_units)
print('free_params ', free_params)
print('kernel_c', lstm.kernel_c.shape)
print('bias_c', lstm.bias_c .shape)
其中4
表示底部图片中f,i,c和o内部路径中的每一个的一个。第一项是内核的权重数,第二项是重复内核的权重,最后一项是偏差(如果应用)。对于
units = 1
timesteps = 1
features = 1
我们看到了
Layer (type) Output Shape Param #
=================================================================
lstm_1 (LSTM) (None, 1) 12
=================================================================
Total params: 12.0
Trainable params: 12
Non-trainable params: 0.0
_________________________________________________________________
num_params 12
kernel_c (1, 1)
bias_c (1,)
和
units = 1
timesteps = 1
features = 2
我们看到了
Layer (type) Output Shape Param #
=================================================================
lstm_1 (LSTM) (None, 1) 16
=================================================================
Total params: 16.0
Trainable params: 16
Non-trainable params: 0.0
_________________________________________________________________
num_params 16
kernel_c (2, 1)
bias_c (1,)
其中bias_c
是状态C的输出形状的代理。请注意,关于单元的内部制作有不同的实现。详细信息在这里(http://deeplearning.net/tutorial/lstm.html),默认实现使用Eq.7。希望这会有所帮助。
答案 1 :(得分:0)
让我们将以上答案更新为TensorFlow 2。
import tensorflow as tf
model = tf.keras.Sequential([tf.keras.layers.LSTM(units, input_shape=(timesteps, features))])
model.summary()
free_params = (4 * features * units) + (4 * units * units) + (4 * num_units)
print('free_params ', free_params)
print('kernel_c', lstm.kernel_c.shape)
print('bias_c', lstm.bias_c .shape)
使用此代码,您也可以在TensorFlow 2.x中获得相同的结果。