Using graph collections to store RNN state

Time: 2017-09-20 08:29:28

Tags: tensorflow

I often use tf.add_to_collection to have TensorFlow automatically serialize intermediate results into a checkpoint. When a model has been restored from a checkpoint, I find this the most convenient way to get pointers to interesting tensors later on. However, I realized that RNN state tuples cannot easily be added to a graph collection. Consider the following dummy example in TF 1.3:

import tensorflow as tf
import numpy as np

in_ = tf.placeholder(tf.float32, shape=[None, 5, 1])
batch_size = tf.shape(in_)[0]

cell1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
cell2 = tf.nn.rnn_cell.BasicLSTMCell(num_units=256)
cell = tf.nn.rnn_cell.MultiRNNCell([cell1, cell2])
outputs, last_state = tf.nn.dynamic_rnn(cell=cell,
                                        inputs=in_,
                                        initial_state=cell.zero_state(batch_size, dtype=tf.float32))
tf.add_to_collection('states', last_state)

loss = tf.reduce_mean(in_ - outputs)
loss_s = tf.summary.scalar('loss', loss)
writer = tf.summary.FileWriter('.', tf.get_default_graph())

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    l, s = sess.run([loss, loss_s], feed_dict={in_: np.ones([1, 5, 1])})
    writer.add_summary(s)

This produces the following warning:

WARNING:tensorflow:Error encountered when serializing states.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'tuple' object has no attribute 'name'

It seems the serialization cannot handle tuples, and last_state is, of course, a tuple. One could probably loop through the tuple and add each element to the collection individually, but that seems overly complicated. What is a better way to handle this? In the end, I would like to access last_state again when the model is restored, ideally without needing the original code that created the model.

1 answer

Answer 0 (score: 2)

Actually, looping through the elements of the state is not that complicated and is easy to implement:

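import tensorflow as tf

def add_to_collection_rnn_state(name, rnn_state):
    # rnn_state holds one LSTMStateTuple(c, h) per layer;
    # store them in order: c0, h0, c1, h1, ...
    for layer in rnn_state:
        tf.add_to_collection(name, layer.c)
        tf.add_to_collection(name, layer.h)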

Then, to load it:

def get_collection_rnn_state(name):
    # rebuild one LSTMStateTuple(c, h) per layer from consecutive (c, h) pairs
    layers = []
    coll = tf.get_collection(name)
    for i in range(0, len(coll), 2):
        state = tf.nn.rnn_cell.LSTMStateTuple(coll[i], coll[i+1])
        layers.append(state)
    return tuple(layers)

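Note that this assumes a collection stores nothing but one state, i.e. use a separate collection for each state you want to store, for example: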

add_to_collection_rnn_state('states', last_state)
add_to_collection_rnn_state('init_state', init_state)
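
Because the store/load pair relies only on the order of the elements in the collection, its behaviour can be checked without TensorFlow. The following is a minimal pure-Python sketch under stated assumptions: LSTMStateTuple is a hypothetical namedtuple stand-in with the same (c, h) fields as tf.nn.rnn_cell.LSTMStateTuple, strings stand in for the state tensors, and the dict graph_collections and the helpers add_state/get_state are hypothetical stand-ins for the graph's named collections and the two functions above. It shows that each state round-trips through its own collection, and what goes wrong when two states share one.

```python
from collections import defaultdict, namedtuple

# Hypothetical stand-ins (not TensorFlow): a namedtuple with the same (c, h)
# fields as tf.nn.rnn_cell.LSTMStateTuple, and a dict of lists playing the
# role of the graph's named collections.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])
graph_collections = defaultdict(list)


def add_state(name, rnn_state):
    # Store c and h of every layer in order: c0, h0, c1, h1, ...
    for layer in rnn_state:
        graph_collections[name].append(layer.c)
        graph_collections[name].append(layer.h)


def get_state(name):
    # Regroup consecutive (c, h) pairs into one LSTMStateTuple per layer.
    coll = graph_collections[name]
    return tuple(LSTMStateTuple(coll[i], coll[i + 1]) for i in range(0, len(coll), 2))


# Two-layer states; strings stand in for the c/h tensors.
last_state = (LSTMStateTuple("last_c0", "last_h0"), LSTMStateTuple("last_c1", "last_h1"))
init_state = (LSTMStateTuple("init_c0", "init_h0"), LSTMStateTuple("init_c1", "init_h1"))

# One collection per state: each one round-trips unchanged.
add_state("states", last_state)
add_state("init_state", init_state)
print(graph_collections["states"])  # ['last_c0', 'last_h0', 'last_c1', 'last_h1']
print(get_state("states") == last_state)      # True
print(get_state("init_state") == init_state)  # True

# Both states in one collection: the loader cannot tell where the first
# state ends, so it returns one 4-layer tuple instead of two 2-layer tuples.
add_state("shared", last_state)
add_state("shared", init_state)
print(len(get_state("shared")))  # 4
```

The last case is exactly why one collection per state is needed: nothing in a flat collection marks where one state ends and the next begins.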

Edit

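As correctly pointed out in the comments, the proposed solution only works for LSTMCell states (which are represented as tuples). A more general solution that can handle GRU cells, or possibly custom cells and mixtures of them, could look like this: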

import tensorflow as tf
import numpy as np

def add_to_collection_rnn_state(name, rnn_state):
    # store the type name of each layer's state in a separate collection
    coll_of_names = name + '__names__'
    for layer in rnn_state:
        n = layer.__class__.__name__
        tf.add_to_collection(coll_of_names, n)
        try:
            for l in layer:
                tf.add_to_collection(name, l)
        except TypeError:
            # layer is not iterable so just add it directly
            tf.add_to_collection(name, layer)


def get_collection_rnn_state(name):
    layers = []
    coll = tf.get_collection(name)
    coll_of_names = tf.get_collection(name + '__names__')
    idx = 0
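    for n in coll_of_names:
        # names restored from a saved meta graph may come back as bytes (Python 3)
        if isinstance(n, bytes):
            n = n.decode('utf-8')
        if 'LSTMStateTuple' in n: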
            state = tf.nn.rnn_cell.LSTMStateTuple(coll[idx], coll[idx+1])
            idx += 2
        else:  # add more cell types here
            state = coll[idx]
            idx += 1
        layers.append(state)
    return tuple(layers)


in_ = tf.placeholder(tf.float32, shape=[None, 5, 1])
batch_size = tf.shape(in_)[0]

cell1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
cell2 = tf.nn.rnn_cell.GRUCell(num_units=256)
cell3 = tf.nn.rnn_cell.BasicRNNCell(num_units=256)
cell = tf.nn.rnn_cell.MultiRNNCell([cell1, cell2, cell3])
outputs, last_state = tf.nn.dynamic_rnn(cell=cell,
                                        inputs=in_,
                                        initial_state=cell.zero_state(batch_size, dtype=tf.float32))

add_to_collection_rnn_state('last_state', last_state)
last_state_r = get_collection_rnn_state('last_state')

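Comparing last_state and last_state_r shows that they are identical (as they should be). Note that I use a separate collection to store the names, because TensorFlow can only serialize a collection if all of its elements are of the same type; for example, mixing strings and Tensors in the same collection does not work.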
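
The type-tag idea behind the general version can likewise be modelled without TensorFlow. This is a minimal pure-Python sketch, assuming: FakeTensor is a hypothetical stand-in for a graph-mode tensor (in TF 1.x graph mode, iterating a tensor raises TypeError, which is what the try/except above relies on), LSTMStateTuple is a stand-in namedtuple, and graph_collections, add_state and get_state are hypothetical stand-ins for the graph's named collections and the two functions above. Type names go into their own all-string list, mirroring the one-type-per-collection rule; the loader also accepts names as bytes, on the assumption that they may come back that way after a save/restore.

```python
from collections import defaultdict, namedtuple

# Hypothetical stand-ins, not TensorFlow classes.
LSTMStateTuple = namedtuple("LSTMStateTuple", ["c", "h"])


class FakeTensor:
    """Stand-in for a graph-mode tensor: no __iter__, so iterating raises TypeError."""

    def __init__(self, name):
        self.name = name

    def __repr__(self):
        return "FakeTensor(%r)" % self.name


graph_collections = defaultdict(list)  # stand-in for the graph's named collections


def add_state(name, rnn_state):
    names_key = name + "__names__"  # type names live in their own all-string collection
    for layer in rnn_state:
        graph_collections[names_key].append(type(layer).__name__)
        try:
            for t in layer:  # tuple-like states (e.g. LSTM) are flattened in order
                graph_collections[name].append(t)
        except TypeError:  # a single non-iterable tensor (e.g. GRU or basic RNN state)
            graph_collections[name].append(layer)


def get_state(name):
    coll = graph_collections[name]
    layers, idx = [], 0
    for n in graph_collections[name + "__names__"]:
        if isinstance(n, bytes):  # tolerate names that come back as bytes
            n = n.decode("utf-8")
        if n == "LSTMStateTuple":
            layers.append(LSTMStateTuple(coll[idx], coll[idx + 1]))
            idx += 2
        else:  # extend with further tuple-like state types here
            layers.append(coll[idx])
            idx += 1
    return tuple(layers)


# Mixed stack, like MultiRNNCell([BasicLSTMCell, GRUCell, BasicRNNCell]).
last_state = (
    LSTMStateTuple(FakeTensor("lstm_c"), FakeTensor("lstm_h")),
    FakeTensor("gru_state"),
    FakeTensor("rnn_state"),
)
add_state("last_state", last_state)
print(graph_collections["last_state__names__"])  # ['LSTMStateTuple', 'FakeTensor', 'FakeTensor']
print(get_state("last_state") == last_state)     # True

# Simulate names restored as bytes; the rebuilt state is still the same.
graph_collections["last_state__names__"] = [
    n.encode("utf-8") for n in graph_collections["last_state__names__"]
]
print(get_state("last_state") == last_state)     # True
```

Keeping the names in a parallel list is what lets the loader know how many consecutive elements to consume per layer; the tensors themselves carry no such boundary information.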