我经常使用import tensorflow as tf
import numpy as np
in_ = tf.placeholder(tf.float32, shape=[None, 5, 1])
batch_size = tf.shape(in_)[0]
cell1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
cell2 = tf.nn.rnn_cell.BasicLSTMCell(num_units=256)
cell = tf.nn.rnn_cell.MultiRNNCell([cell1, cell2])
outputs, last_state = tf.nn.dynamic_rnn(cell=cell,
inputs=in_,
initial_state=cell.zero_state(batch_size, dtype=tf.float32))
tf.add_to_collection('states', last_state)
loss = tf.reduce_mean(in_ - outputs)
loss_s = tf.summary.scalar('loss', loss)
writer = tf.summary.FileWriter('.', tf.get_default_graph())
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
l, s = sess.run([loss, loss_s], feed_dict={in_: np.ones([1, 5, 1])})
writer.add_summary(s)
让Tensorflow自动将中间结果序列化为检查点。当从检查点恢复模型时,我发现这是后来获取指向有趣张量指针的最方便的方法。但是,我意识到RNN状态元组不能轻易添加到图集合中。考虑TF 1.3中的以下虚拟示例:
WARNING:tensorflow:Error encountered when serializing states.
Type is unsupported, or the types of the items don't match field type in CollectionDef.
'tuple' object has no attribute 'name'
这将产生以下警告:
last_state
似乎序列化无法处理元组,当然last_state
变量是一个元组。也许可以循环遍历元组并将每个元素分别添加到集合中,但这看起来太复杂了。有什么更好的方法来处理这个问题?最后,我想在模型恢复时再次访问{{1}},理想情况下无需访问创建模型的原始代码。
答案 0 :(得分:2)
实际上,循环遍历状态的每个元素并不太复杂,而且可以直接实现:
def add_to_collection_rnn_state(name, rnn_state):
for layer in rnn_state:
tf.add_to_collection(name, layer.c)
tf.add_to_collection(name, layer.h)
然后加载它:
def get_collection_rnn_state(name):
layers = []
coll = tf.get_collection(name)
for i in range(0, len(coll), 2):
state = tf.nn.rnn_cell.LSTMStateTuple(coll[i], coll[i+1])
layers.append(state)
return tuple(layers)
请注意,这假设一个集合仅存储状态,即为您要存储的每个州使用不同的集合,例如:像这样:
add_to_collection_rnn_state('states', last_state)
add_to_collection_rnn_state('init_state', init_state)
修改强>
正如在评论中正确指出的那样,所提出的解决方案仅适用于LSTMCell(也表示为元组)。可以处理GRU单元或可能定制单元及其混合的更通用的解决方案可能如下所示:
import tensorflow as tf
import numpy as np
def add_to_collection_rnn_state(name, rnn_state):
# store the name of each cell type in a different collection
coll_of_names = name + '__names__'
for layer in rnn_state:
n = layer.__class__.__name__
tf.add_to_collection(coll_of_names, n)
try:
for l in layer:
tf.add_to_collection(name, l)
except TypeError:
# layer is not iterable so just add it directly
tf.add_to_collection(name, layer)
def get_collection_rnn_state(name):
layers = []
coll = tf.get_collection(name)
coll_of_names = tf.get_collection(name + '__names__')
idx = 0
for n in coll_of_names:
if 'LSTMStateTuple' in n:
state = tf.nn.rnn_cell.LSTMStateTuple(coll[idx], coll[idx+1])
idx += 2
else: # add more cell types here
state = coll[idx]
idx += 1
layers.append(state)
return tuple(layers)
in_ = tf.placeholder(tf.float32, shape=[None, 5, 1])
batch_size = tf.shape(in_)[0]
cell1 = tf.nn.rnn_cell.BasicLSTMCell(num_units=128)
cell2 = tf.nn.rnn_cell.GRUCell(num_units=256)
cell3 = tf.nn.rnn_cell.BasicRNNCell(num_units=256)
cell = tf.nn.rnn_cell.MultiRNNCell([cell1, cell2, cell3])
outputs, last_state = tf.nn.dynamic_rnn(cell=cell,
inputs=in_,
initial_state=cell.zero_state(batch_size, dtype=tf.float32))
add_to_collection_rnn_state('last_state', last_state)
last_state_r = get_collection_rnn_state('last_state')
比较last_state
和last_state_r
会发现两者都是相同的(它们应该是相同的)。请注意,我使用不同的集合来存储名称,因为tensorflow只能在集合中的所有元素属于同一类型时序列化集合。例如。在同一个集合中使用Tensors混合字符串不起作用。