我已经训练了一个seq2seq模型并将其保存为tf.Saver()
。
现在,我一直在玩的模型有这些占位符:
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')
decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')
让我们关注decoder_targets
。如果我没记错的话,这个占位符是想要的结果。这是个问题。如果我想预测翻译(这意味着我没有“想要的结果”),我该怎么办?
这就是我目前的
import tensorflow as tf
tf.reset_default_graph()
saver = tf.train.Saver()
encoder_inputs = tf.placeholder(shape=(None, None), dtype=tf.int32, name='encoder_inputs')
encoder_inputs_length = tf.placeholder(shape=(None,), dtype=tf.int32, name='encoder_inputs_length')
decoder_targets = tf.placeholder(shape=(None, None), dtype=tf.int32, name='decoder_targets')
encoder_inputs_ = 0
encoder_inputs_length_
lol = 0
with tf.Session() as sess:
# Restore model
saver.restore(sess, "./rsc/model.ckpt")
# Init
sess.run(tf.initialize_all_variables())
# Test
sess.run(lol, feed_dict={encoder_inputs: encoder_inputs_, encoder_inputs_length: encoder_inputs_length_, ?????})
我是否也需要在此文件中编写整个模型? (对不起代码,我正在测试一些东西)