我编写了一个RNN,它以字符级别查看段落,并希望将其保存以供以后使用。一些代码如下:
cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
batch_size = tf.placeholder(tf.int32, [], name='batch_size')
multi_cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
init_state = multi_cell.zero_state(batch_size, dtype=tf.float32)
rnn_outputs, final_state = tf.nn.dynamic_rnn(multi_cell, rnn_inputs, initial_state=init_state)
with tf.variable_scope('softmax'):
W = tf.get_variable('W', [state_size, num_classes])
b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))
rnn_outputs = tf.reshape(rnn_outputs, [-1, state_size])
y_reshaped = tf.reshape(y, [-1])
logits = tf.matmul(rnn_outputs, W) + b
predictions = tf.nn.softmax(logits, name="predictions")
total_loss = tf.reduce_mean(
tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=logits,
labels=y_reshaped
)
)
train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)
然后我用tf.train.Saver()
和saver.save(sess, "path/to/save")
保存我的模型。
然后,我尝试在另一个脚本中加载模型并使用以下代码生成文本:
tf.reset_default_graph()
imported_meta = tf.train.import_meta_graph("path/to/save/save_file.meta")
with tf.Session() as sess:
imported_meta.restore(sess, tf.train.latest_checkpoint("path/to/save"))
x = sess.graph.get_tensor_by_name("input_placeholder:0")
batch_size_tensor = sess.graph.get_tensor_by_name("batch_size:0")
predictions = sess.graph.get_tensor_by_name("predictions:0")
state = None
current_char = vocab_to_index[start_char]
for i in range(num_chars):
if state is not None:
feed_dict={batch_size_tensor: batch_size, x: [[current_char]], init_state: state}
else:
feed_dict={batch_size_tensor: batch_size, x: [[current_char]]}
rnn_outputs, state = sess.run(
[predictions, final_state],
feed_dict
)
基本上,我想在这里做的是输入一个字符,然后根据前一个字符再生成一个字符。在初始字符之后,dynamic_rnn中的final_state
应该为sess.run()
,并作为init_state
进入下一个生成过程。但是,我找不到将训练代码中定义的init_state
和final_state
保存到测试代码中的方法,对于这些操作,没有像tf.nn.softmax
那样的“ name”参数
我想要的是一些类似final_state = sess.graph.get_operation_by_name('final_state')
的代码,以便我可以sess.run(final_state)
并将其作为init_state
反馈。
我尝试在训练代码和tf.add_to_collection("some_name", final_state)
中使用tf.get_collection("some_name")
,但是错误提示在测试图中找不到集合“ some_name”。
编写文本生成模型的人在生成阶段遇到了这个问题吗?或者人们如何生成文本/保存并加载来自dynamic_rnn的输出?
非常感谢!