Tensorflow-关于从受过训练的RNN传递操作并生成文本的问题

时间:2019-01-02 02:59:27

标签: python tensorflow recurrent-neural-network

我编写了一个RNN,它以字符级别查看段落,并希望将其保存以供以后使用。一些代码如下:

cell = tf.nn.rnn_cell.LSTMCell(state_size, state_is_tuple=True)
batch_size = tf.placeholder(tf.int32, [], name='batch_size')
multi_cell = tf.nn.rnn_cell.MultiRNNCell([cell] * num_layers, state_is_tuple=True)
init_state = multi_cell.zero_state(batch_size, dtype=tf.float32)
rnn_outputs, final_state = tf.nn.dynamic_rnn(multi_cell, rnn_inputs, initial_state=init_state)

with tf.variable_scope('softmax'):
    W = tf.get_variable('W', [state_size, num_classes])
    b = tf.get_variable('b', [num_classes], initializer=tf.constant_initializer(0.0))

rnn_outputs = tf.reshape(rnn_outputs, [-1, state_size])
y_reshaped = tf.reshape(y, [-1])

logits = tf.matmul(rnn_outputs, W) + b
predictions = tf.nn.softmax(logits, name="predictions")

total_loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(
        logits=logits, 
        labels=y_reshaped
    )
)
train_step = tf.train.AdamOptimizer(learning_rate).minimize(total_loss)

然后我用tf.train.Saver()saver.save(sess, "path/to/save")保存我的模型。

然后,我尝试在另一个脚本中加载模型并使用以下代码生成文本:

tf.reset_default_graph()
imported_meta = tf.train.import_meta_graph("path/to/save/save_file.meta")

with tf.Session() as sess:
    imported_meta.restore(sess, tf.train.latest_checkpoint("path/to/save"))
    x = sess.graph.get_tensor_by_name("input_placeholder:0")
    batch_size_tensor = sess.graph.get_tensor_by_name("batch_size:0")
    predictions = sess.graph.get_tensor_by_name("predictions:0")

    state = None
    current_char = vocab_to_index[start_char]

    for i in range(num_chars):
        if state is not None:
            feed_dict={batch_size_tensor: batch_size, x: [[current_char]], init_state: state}
        else:
            feed_dict={batch_size_tensor: batch_size, x: [[current_char]]}

        rnn_outputs, state = sess.run(
            [predictions, final_state], 
            feed_dict
        )

基本上,我想在这里做的是输入一个字符,然后根据前一个字符再生成一个字符。在初始字符之后,dynamic_rnn中的final_state应该为sess.run(),并作为init_state进入下一个生成过程。但是,我找不到将训练代码中定义的init_statefinal_state保存到测试代码中的方法,对于这些操作,没有像tf.nn.softmax那样的“ name”参数

我想要的是一些类似final_state = sess.graph.get_operation_by_name('final_state')的代码,以便我可以sess.run(final_state)并将其作为init_state反馈。

我尝试在训练代码和tf.add_to_collection("some_name", final_state)中使用tf.get_collection("some_name"),但是错误提示在测试图中找不到集合“ some_name”。

编写文本生成模型的人在生成阶段遇到了这个问题吗?或者人们如何生成文本/保存并加载来自dynamic_rnn的输出?

非常感谢!

0 个答案:

没有答案