How to reset the GRU state in tensorflow after every epoch

Asked: 2018-01-30 14:35:44

Tags: tensorflow machine-learning deep-learning recurrent-neural-network gated-recurrent-unit

I am implementing an RNN using TensorFlow GRU cells. I am working with the videos mentioned earlier, which are up to 5 minutes long. Since the next state is fed automatically into the GRU, how can I manually reset the RNN's state after each epoch? In other words, I want the initial state at the start of training to always be 0. Here is a snippet of my code:

with tf.variable_scope('GRU'):
    latent_var = tf.reshape(latent_var, shape=[batch_size, time_steps, latent_dim])

    cell = tf.nn.rnn_cell.GRUCell(cell_size)
    H, C = tf.nn.dynamic_rnn(cell, latent_var, dtype=tf.float32)
    H = tf.reshape(H, [batch_size, cell_size])
...

Any help is greatly appreciated!

1 Answer:

Answer 0 (score: 1)

Use the initial_state parameter of tf.nn.dynamic_rnn:


initial_state: (Optional) An initial state for the RNN. If cell.state_size is an integer, this must be a Tensor of appropriate type and shape [batch_size, cell.state_size]. If cell.state_size is a tuple, this should be a tuple of tensors having shapes [batch_size, s] for s in cell.state_size.

An adapted example from the documentation:

# create a GRUCell
cell = tf.nn.rnn_cell.GRUCell(cell_size)

# 'outputs' is a tensor of shape [batch_size, max_time, cell_state_size]

# defining initial state
initial_state = cell.zero_state(batch_size, dtype=tf.float32)

# 'state' is a tensor of shape [batch_size, cell_state_size]
outputs, state = tf.nn.dynamic_rnn(cell, input_data,
                                   initial_state=initial_state,
                                   dtype=tf.float32)

Also note that even though initial_state is not a placeholder, you can still feed a value to it. So if you want to preserve the state within an epoch, but start with zeros at the beginning of each epoch, you can do it like this:

# Compute the zero state array of the right shape once
zero_state = sess.run(initial_state)

# Start with a zero vector and update it
cur_state = zero_state
for batch in get_batches():
  cur_state, _ = sess.run([state, ...], feed_dict={initial_state: cur_state, ...})
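To make the reset pattern concrete without depending on a TF session, here is a framework-agnostic sketch of the same loop using a hand-rolled GRU step in NumPy. All sizes and weight names (`W_z`, `W_r`, `W_h`) are hypothetical, chosen only for illustration; the point is just the state handling: carry `cur_state` across batches, and reassign it to zeros at the top of every epoch.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, input_dim, cell_size = 2, 3, 4

# Randomly initialised GRU parameters (illustrative only)
W_z = rng.standard_normal((input_dim + cell_size, cell_size))
W_r = rng.standard_normal((input_dim + cell_size, cell_size))
W_h = rng.standard_normal((input_dim + cell_size, cell_size))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x, h):
    """One GRU update: gates computed from [x, h], then a blended new state."""
    xh = np.concatenate([x, h], axis=1)
    z = sigmoid(xh @ W_z)                                  # update gate
    r = sigmoid(xh @ W_r)                                  # reset gate
    h_tilde = np.tanh(np.concatenate([x, r * h], axis=1) @ W_h)
    return (1 - z) * h + z * h_tilde

zero_state = np.zeros((batch_size, cell_size))

for epoch in range(2):
    cur_state = zero_state                 # reset to zeros at each epoch start
    for batch in range(3):                 # state carries over between batches
        x = rng.standard_normal((batch_size, input_dim))
        cur_state = gru_step(x, cur_state)
```

This mirrors the session loop above: `zero_state` plays the role of the evaluated `cell.zero_state(...)` array, and reassigning `cur_state = zero_state` is the manual reset that feeding `initial_state` achieves in the TF version.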