Understanding the ConvLSTM implementation in TensorFlow

Date: 2020-12-28 00:20:18

Tags: python tensorflow deep-learning lstm

In the TensorFlow implementation of the ConvLSTM cell, the following lines of code appear:

    x_i = self.input_conv(inputs_i, kernel_i, bias_i, padding=self.padding)
    x_f = self.input_conv(inputs_f, kernel_f, bias_f, padding=self.padding)
    x_c = self.input_conv(inputs_c, kernel_c, bias_c, padding=self.padding)
    x_o = self.input_conv(inputs_o, kernel_o, bias_o, padding=self.padding)
    h_i = self.recurrent_conv(h_tm1_i, recurrent_kernel_i)
    h_f = self.recurrent_conv(h_tm1_f, recurrent_kernel_f)
    h_c = self.recurrent_conv(h_tm1_c, recurrent_kernel_c)
    h_o = self.recurrent_conv(h_tm1_o, recurrent_kernel_o)

    i = self.recurrent_activation(x_i + h_i)
    f = self.recurrent_activation(x_f + h_f)
    c = f * c_tm1 + i * self.activation(x_c + h_c)
    o = self.recurrent_activation(x_o + h_o)
    h = o * self.activation(c)

The corresponding equations described in the paper are:

[Image: the ConvLSTM equations from the paper]
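The equation image did not survive here. As a reference only, and assuming the paper in question is Shi et al. (2015), "Convolutional LSTM Network", the ConvLSTM equations are usually written as below, where $*$ denotes convolution and $\circ$ the Hadamard (element-wise) product. The terms $W_{ci}$, $W_{cf}$ and $W_{co}$ are the ones the question asks about.

```latex
\begin{aligned}
i_t &= \sigma\left(W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i\right) \\
f_t &= \sigma\left(W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f\right) \\
C_t &= f_t \circ C_{t-1} + i_t \circ \tanh\left(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c\right) \\
o_t &= \sigma\left(W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o\right) \\
H_t &= o_t \circ \tanh\left(C_t\right)
\end{aligned}
```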

I can't see how W_ci, W_cf, W_co, C_{t-1}, and C_t are used in the input, forget, and output gates. Where are they used in computing the 4 gates?

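1 Answer:

Answer 0: (score: 0)

You won't find those terms in the ConvLSTM cell implementation, because it does not use peephole connections:

> Peephole connections allow gates to utilize the previous internal state as well as the previous hidden state (which is what LSTMCell is limited to)

tf.keras.experimental.PeepholeLSTMCell follows the equations you posted above, as you can see in its source code: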

    x_i, x_f, x_c, x_o = x
    h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
    i = self.recurrent_activation(
        x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, :self.units]) +
        self.input_gate_peephole_weights * c_tm1)
    f = self.recurrent_activation(x_f + K.dot(
        h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) +
        self.forget_gate_peephole_weights * c_tm1)
    c = f * c_tm1 + i * self.activation(x_c + K.dot(
        h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3]))
    o = self.recurrent_activation(
        x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) +
        self.output_gate_peephole_weights * c)

Or, even more clearly, look at the source code of tf.compat.v1.nn.rnn_cell.LSTMCell:

    if self._use_peepholes:
      c = (
          sigmoid(f + self._forget_bias + self._w_f_diag * c_prev) * c_prev +
          sigmoid(i + self._w_i_diag * c_prev) * self._activation(j))
    else:
      c = (
          sigmoid(f + self._forget_bias) * c_prev +
          sigmoid(i) * self._activation(j))
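
To make the difference concrete, here is a minimal pure-Python sketch of one LSTM step for a single unit. It is not TensorFlow code: the function name `lstm_step` and the parameters `w_ci`, `w_cf`, `w_co` are hypothetical, and the pre-activations `pre_*` stand in for the sums `x_* + h_*` from the question. It assumes sigmoid gates and a tanh activation. With the peephole weights left at zero it reduces to the form used in the ConvLSTM cell; with non-zero weights the cell state enters the gates as in the paper's equations.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def lstm_step(pre_i, pre_f, pre_c, pre_o, c_tm1, w_ci=0.0, w_cf=0.0, w_co=0.0):
    """One LSTM step for a single unit (illustrative only).

    pre_* are the pre-activations x_* + h_* for each gate.
    w_ci, w_cf, w_co are hypothetical peephole weights; 0.0 disables them.
    """
    i = sigmoid(pre_i + w_ci * c_tm1)      # input gate peeks at C_{t-1}
    f = sigmoid(pre_f + w_cf * c_tm1)      # forget gate peeks at C_{t-1}
    c = f * c_tm1 + i * math.tanh(pre_c)   # new cell state C_t
    o = sigmoid(pre_o + w_co * c)          # output gate peeks at C_t
    h = o * math.tanh(c)                   # new hidden state H_t
    return h, c


# Same inputs, without and with peephole connections.
pre = dict(pre_i=0.5, pre_f=-0.2, pre_c=0.3, pre_o=0.1)
h_plain, c_plain = lstm_step(**pre, c_tm1=1.0)
h_peep, c_peep = lstm_step(**pre, c_tm1=1.0, w_ci=0.4, w_cf=-0.3, w_co=0.2)
print(h_plain, c_plain)
print(h_peep, c_peep)
```

The two printed pairs differ, which is the whole point: the ConvLSTM cell in the question corresponds to the first call, so the W_c* terms never appear in its code.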