Tensorflow,如何在feed_dict中传递MultiRNN状态

时间:2016-10-14 10:19:57

标签: tensorflow

我正在尝试在tensorflow中制作一个生成RNN模型。令我讨厌的是,默认情况下,在RNN库中新切换到state_is_tupe是真的,我很难找到在批次之间保存状态的最佳方法。我知道我可以将其改回为假,但我不想这样做,因为它已被弃用。当我完成训练时,我需要能够在调用session.run之间保持隐藏状态,因为我将一次生成一个样本的序列。我发现我可以按如下方式返回rnn的状态。

ClusteringOrder

这会很棒但是当我想将state_output传递回模型时会出现问题。由于占位符只能是张量对象,因此我无法将其传递回state_output tupel。

我正在寻找一种非常通用的解决方案。 rnn可以是MultiRNNCell或单个LSTMCell或可想象的任何其他组合。

1 个答案:

答案 0 :(得分:0)

我想我明白了。我使用以下代码将状态元组展平为单个1D张量。当我根据rnn单元格的尺寸规格将其传回模型时,我可以把它砍掉。

def flatten_state_tupel(x):
    result = []
    for x_ in x:
        if isinstance(x_, tf.Tensor) or not hasattr(x_, '__iter__'):
            result.append(x_)
        else:
            result.extend(flatten_state_tupel(x_))
    return result

def pack_state_tupel(state):
    return tf.concat(0, [tf.reshape(s, (-1,)) for s in flatten_state_tupel(state)])

def unpack_state_tupel(state, size):
    state = tf.reshape(state, (-1, tf.reduce_sum(flatten_state_tupel(size))))
    def _make_state_tupel(sz, i):
        if hasattr(sz, '__iter__'):
            result = []
            for s in sz:
                base_index, y = _make_state_tupel(s, i)
                result.append(y)
            return base_index, tf.nn.rnn_cell.LSTMStateTuple(*result) if isinstance(sz, tf.nn.rnn_cell.LSTMStateTuple) else tuple(result)
        else:
            return i + sz, state[..., i : i + sz]
    return _make_state_tupel(size, 0)[-1]

我使用如下函数。

rnn = tf.nn.rnn_cell.MultiRNNCell(cells)  
zero_state = pack_state_tupel(rnn.zero_state(batch_size, tf.float32))
self.initial_state = tf.placeholder_with_default(zero_state, None)

output, final_state = tf.nn.dynamic_rnn(rnn, self.input_sound, initial_state = unpack_state_tupel(self.initial_state, rnn.state_size))

packed_state = pack_state_tupel(final_state)

sess = tf.Session()
sess.run(tf.initialize_all_variables())

state_output = sess.run(packed_state, feed_dict = {self.input_sound: np.zeros((64, 32, 512))})
print(state_output.shape)
state_output = sess.run(packed_state, feed_dict = {self.input_sound: np.zeros((64, 32, 512)), self.initial_state: np.zeros(state_output.shape[0])})
print(state_output)

这样,如果我没有传递任何东西(在训练期间就是这种情况),它将使状态归零但是我可以在代数之间保存并传递状态。