无法重新保存张量流会话而不会再次保存

时间:2017-04-19 19:10:58

标签: python tensorflow

以下是代码:

import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
import tensorflow as tf


# print(os.getcwd())
# os.chdir(os.getcwd())
# os.chdir("/tmp")


chk_file = "hello.chk"
def save(checkpoint_file=chk_file):
    with tf.Session() as session:
        x = tf.Variable(initial_value=[1, 2, 3], name="x")
        y = tf.Variable(initial_value=[[1.0, 2.0], [3.0, 4.0]], name="y")
        session.run(tf.global_variables_initializer())
        print(session.run(tf.global_variables()))
        saver = tf.train.Saver()
        save_path = saver.save(sess=session, save_path=checkpoint_file)
        print(session.run(tf.global_variables()))



def restore(checkpoint_file=chk_file):
    with tf.Session() as session:
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=checkpoint_file)
        print(session.run(tf.global_variables()[0]))
        print(tf.global_variables()[0])
        # print(session.run(tf.get_variable("x", shape=(3, ))))

def reset():
    tf.reset_default_graph()

path = save()
# print(path)
restore("/home/kaiyin/PycharmProjects/text-classify/hello.chk")

一些问题:

  • restore(path)不起作用,与saver.restore

  • 文档中描述的内容相反
  • 相对路径对restore不起作用,即使您已经在正确的目录中

  • 如果您注释掉path = save()行,则会收到错误消息:

    / home / kaiyin / virtualenvs / tensorflow / bin / python /home/kaiyin/PycharmProjects/text-classify/restore.py Traceback(最近一次调用最后一次):   文件" /home/kaiyin/PycharmProjects/text-classify/restore.py" ;,第38行,在     恢复(" /home/kaiyin/PycharmProjects/text-classify/hello.chk")   文件" /home/kaiyin/PycharmProjects/text-classify/restore.py",第27行,还原     saver = tf.train.Saver()   文件" /home/kaiyin/virtualenvs/tensorflow/lib/python3.5/site-packages/tensorflow/python/training/saver.py" ;,第1040行, init     self.build()   文件" /home/kaiyin/virtualenvs/tensorflow/lib/python3.5/site-packages/tensorflow/python/training/saver.py" ;,第1061行,在构建中     提出ValueError("没有要保存的变量") ValueError:没有要保存的变量

    流程已完成退出代码1

我可以忍受前两个问题,但第三个问题是一个真正的阻碍者。为什么我需要保存会话每次我想恢复它?由于没有全局会话对象,save函数如何产生这样的影响也有点神秘。

Tensorflow版本:1.0.1

Python 3.5.2

Ubuntu 16.04

1 个答案:

答案 0 :(得分:1)

我整天都很头疼。

刚刚解决了:

正确的方法是在调用restore()

之前需要重新初始化所有变量

例如,在cifar10项目中(cifar10.py - 第188行)

如果您想恢复以前保存的变量

首先需要调用inference()来初始化所有变量 然后调用restore()。

<强>重新初始化

def restore(checkpoint_file=chk_file):
    with tf.Session() as session:
        x = tf.Variable(initial_value=[1, 2, 3], name="x")
        y = tf.Variable(initial_value=[[1.0, 2.0], [3.0, 4.0]], name="y")
        saver = tf.train.Saver()
        saver.restore(sess=session, save_path=checkpoint_file)
        print(session.run(tf.global_variables()[1]))
        print(tf.global_variables()[0])
        # print(session.run(tf.get_variable("x", shape=(3, ))))

def reset():
    tf.reset_default_graph()

restore("/home/kaiyin/PycharmProjects/text-classify/hello.chk")