了解LSTM细胞的递归神经网络的功能

时间:2017-07-11 07:22:29

标签: tensorflow lstm recurrent-neural-network

背景:

  • 我有一个带LSTM细胞的递归神经网络
  • 在我的情况下,网络的输入是一批大小(batch_size,number_of_timesteps,one_hot_encoded_class)(128,300,38)
  • 批次(1-128)的不同行不一定相关 彼此
  • 一个时间步的目标由下一个的值给出 时间步。

我的问题: 当我使用输入批次(128,300,38)和相同大小的目标批次训练网络时,

  1. 网络是否始终只考虑最后一个时间步 t 来预测下一个时间步的值 t + 1 < / EM>

  2. 还是考虑从序列开始到时间步 t 的所有时间步骤?

  3. 或者LSTM单元格内部记住以前的所有状态吗?

  4. 我对功能感到困惑,因为网络是在多个时间步骤上同时训练的,所以我不确定LSTM细胞如何仍然可以了解以前的状态。

    我希望有人可以提供帮助。提前谢谢!

    dicussion代码:

                cells = []
    
                for i in range(self.n_layers):
                    cell = tf.contrib.rnn.LSTMCell(self.n_hidden)
                    cells.append(cell)
    
                cell = tf.contrib.rnn.MultiRNNCell(cells)
                init_state = cell.zero_state(self.batch_size, tf.float32)
    
                outputs, final_state = tf.nn.dynamic_rnn(
                    cell, inputs=self.inputs, initial_state=init_state)
    
                self.logits = tf.contrib.layers.linear(outputs, self.num_classes)
    
                softmax_ce = tf.nn.sparse_softmax_cross_entropy_with_logits(
                    labels=labels, logits=self.logits)
    
                self.loss = tf.reduce_mean(softmax_ce)
                self.train_step = tf.train.AdamOptimizer(self.lr).minimize(self.loss)
    

2 个答案:

答案 0 :(得分:1)

enter image description here

以上是一个简单的RNN展开到神经元级别,有3个时间步长。

您可以看到时间步 t 的输出取决于从头开始的所有时间步。使用back-propagation through time训练网络,其中权重通过所有误差梯度随时间的贡献来更新。权重是跨时间共享的,因此在所有时间步骤中都没有simultaneous update

先前状态的知识通过状态变量 s_t 传输,因为它是先前输入的函数。因此,在任何时间步骤,基于当前输入以及状态变量捕获的先前输入的(函数)进行预测。

注意:由于简单,使用了基本rnn代替LSTM

答案 1 :(得分:0)

以下具体说明您的案例会有所帮助:

给定[128, 300, 38]

的输入形状
  • dynamic_rnn的一次调用将传播到所有300个步骤,如果您使用的是LSTM,那么状态也将通过这300个步骤进行。
  • 但是,对dynamic_rnn的每次SUBSEQUENT调用都不会自动记住上一次调用的状态。通过第二次调用,权重/等。由于第一次通话,它将被更新,但您仍需要将第一次通话产生的状态传递给第二次通话。这就是dynamic_rnn具有参数initial_state的原因,以及为什么其中一个输出为final_state(即处理一次调用中的所有300个步骤后的状态)。所以你打算从N调用最终状态并将其作为调用N + 1的初始状态传递给dynamic_rnn。这与LSTM特别相关,因为这就是你要求的
  • 您应该注意,一批中的元素不必在同一批次中彼此相关。这是您需要仔细考虑的事情。因为连续调用dynamic_rnn,输入序列中的批处理元素必须与前一个/后一个序列中的各自对应元素相关,但不能相互关联。即第一次调用中的元素3可能与同一批次中的其他127个元素无关,但NEXT调用中的元素3必须是PREVIOUS调用中元素3的时间/逻辑延续,依此类推。这样,你不断前进的状态是有意义的