我正在从这样的检查点加载/恢复模型:
ckpt_path = tf.train.latest_checkpoint(self.checkpoints_dir)
config = {'data_dir': os.path.dirname(self.vocab_filename), 'beam_size': 1, 'alpha': 0.6}
# Restore
graph = tf.Graph()
session = tf.Session(graph=graph, config=tf.ConfigProto(allow_soft_placement=True))
with graph.as_default():
_, self.input_node_name, self.output_node_name = load_translation_model(ckpt_path, config)
meta_graph = tf.train.import_meta_graph(ckpt_path + '.meta')
session.run(tf.global_variables_initializer())
meta_graph.restore(session, ckpt_path)
# Test restored model
encoder_inputs = self.encode('how are you doing')
sample = {'%s:0' % self.input_node_name: encoder_inputs}
output_node = graph.get_tensor_by_name('%s:0' % self.output_node_name)
result = session.run(output_node, feed_dict=sample)[0]
decoded = self.decode(result)
没有问题,除了decoded
输出只是垃圾。对我来说,好像模型无法正确还原,我正在随机初始化的变量上运行-可能是这种情况吗?
如果是这样,加载模型时我在做什么错了?