在Keras中是否有一种方法可以在给定输入的每个时间步上检索LSTM层的单元状态(即 c 矢量)?
似乎return_state
参数返回计算完成后的最后一个单元格状态,但我也需要中间的状态。另外,我不想将这些单元状态传递给下一层,我只希望能够访问它们。
最好使用TensorFlow作为后端。
谢谢
答案 0 :(得分:2)
我知道已经很晚了,希望对您有所帮助。
通过在调用方法中修改LSTM单元,从技术上讲您要问的是可能的。我对其进行了修改,并在您提供return_sequences=True
时使其返回4维而不是3维。
代码
from keras.layers.recurrent import _generate_dropout_mask
class Mod_LSTMCELL(LSTMCell):
def call(self, inputs, states, training=None):
if 0 < self.dropout < 1 and self._dropout_mask is None:
self._dropout_mask = _generate_dropout_mask(
K.ones_like(inputs),
self.dropout,
training=training,
count=4)
if (0 < self.recurrent_dropout < 1 and
self._recurrent_dropout_mask is None):
self._recurrent_dropout_mask = _generate_dropout_mask(
K.ones_like(states[0]),
self.recurrent_dropout,
training=training,
count=4)
# dropout matrices for input units
dp_mask = self._dropout_mask
# dropout matrices for recurrent units
rec_dp_mask = self._recurrent_dropout_mask
h_tm1 = states[0] # previous memory state
c_tm1 = states[1] # previous carry state
if self.implementation == 1:
if 0 < self.dropout < 1.:
inputs_i = inputs * dp_mask[0]
inputs_f = inputs * dp_mask[1]
inputs_c = inputs * dp_mask[2]
inputs_o = inputs * dp_mask[3]
else:
inputs_i = inputs
inputs_f = inputs
inputs_c = inputs
inputs_o = inputs
x_i = K.dot(inputs_i, self.kernel_i)
x_f = K.dot(inputs_f, self.kernel_f)
x_c = K.dot(inputs_c, self.kernel_c)
x_o = K.dot(inputs_o, self.kernel_o)
if self.use_bias:
x_i = K.bias_add(x_i, self.bias_i)
x_f = K.bias_add(x_f, self.bias_f)
x_c = K.bias_add(x_c, self.bias_c)
x_o = K.bias_add(x_o, self.bias_o)
if 0 < self.recurrent_dropout < 1.:
h_tm1_i = h_tm1 * rec_dp_mask[0]
h_tm1_f = h_tm1 * rec_dp_mask[1]
h_tm1_c = h_tm1 * rec_dp_mask[2]
h_tm1_o = h_tm1 * rec_dp_mask[3]
else:
h_tm1_i = h_tm1
h_tm1_f = h_tm1
h_tm1_c = h_tm1
h_tm1_o = h_tm1
i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
self.recurrent_kernel_i))
f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
self.recurrent_kernel_f))
c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
self.recurrent_kernel_c))
o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
self.recurrent_kernel_o))
else:
if 0. < self.dropout < 1.:
inputs *= dp_mask[0]
z = K.dot(inputs, self.kernel)
if 0. < self.recurrent_dropout < 1.:
h_tm1 *= rec_dp_mask[0]
z += K.dot(h_tm1, self.recurrent_kernel)
if self.use_bias:
z = K.bias_add(z, self.bias)
z0 = z[:, :self.units]
z1 = z[:, self.units: 2 * self.units]
z2 = z[:, 2 * self.units: 3 * self.units]
z3 = z[:, 3 * self.units:]
i = self.recurrent_activation(z0)
f = self.recurrent_activation(z1)
c = f * c_tm1 + i * self.activation(z2)
o = self.recurrent_activation(z3)
h = o * self.activation(c)
if 0 < self.dropout + self.recurrent_dropout:
if training is None:
h._uses_learning_phase = True
return tf.expand_dims(tf.concat([h,c],axis=0),0), [h, c]
示例代码
# create a cell
test = Mod_LSTMCELL(100)
# Input timesteps=10, features=7
in1 = Input(shape=(10,7))
out1 = RNN(test, return_sequences=True)(in1)
M = Model(inputs=[in1],outputs=[out1])
M.compile(keras.optimizers.Adam(),loss='mse')
ans = M.predict(np.arange(7*10,dtype=np.float32).reshape(1, 10, 7))
print(ans.shape)
# state_h
print(ans[0,0,0,:])
# state_c
print(ans[0,0,1,:])
答案 1 :(得分:0)
首先,使用tf.keras.layers.LSTM不可能做到这一点。您必须改用LSTMCell或LSTM的子类。其次,无需继承LSTMCell即可获取单元状态的序列。每次调用LSTMCell时,它都已返回隐藏状态(h)和单元状态(c)的列表。 对于那些不熟悉LSTMCell的人,它将采用当前[h,c]张量,并在当前时间步长输入(它不能采用一系列时间),并返回激活和更新的[h,c]。 这是显示如何使用LSTMCell处理时间步序列并返回累积的单元状态的示例。
# example inputs
inputs = tf.convert_to_tensor(np.random.rand(3, 4), dtype='float32') # 3 timesteps, 4 features
h_c = [tf.zeros((1,2)), tf.zeros((1,2))] # must initialize hidden/cell state for lstm cell
h_c = tf.convert_to_tensor(h_c, dtype='float32')
lstm = tf.keras.layers.LSTMCell(2)
# example of how you accumulate cell state over repeated calls to LSTMCell
inputs = tf.unstack(inputs, axis=0)
c_states = []
for cur_inputs in inputs:
out, h_c = lstm(tf.expand_dims(cur_inputs, axis=0), h_c)
h, c = h_c
c_states.append(c)
答案 2 :(得分:-1)
您可以通过在初始化程序中设置<select ng-model= "province" aria-controls="example1" class="form-control form-control-sm" multiple="" ng-options=" x as value for (x, y) in optionData">
</select>
来访问任何RNN
的状态。您可以找到有关此here的更多信息。