Tensorflow ValueError:没有要保存的变量(但有很多)

时间:2018-05-29 00:01:36

标签: tensorflow

我正在尝试从已从磁盘加载的图表中保存模型。我可以加载图表并毫无问题地检查它,然后运行训练操作,但是我无法创建一个保护程序而没有得到ValueError:没有要保存的变量。

图表定义:

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pathlib import Path
import os
import tensorflow as tf

outdir = os.path.dirname(__file__)
outfile = Path(__file__).stem + ".pb"

print(os.path.join(outdir, outfile))

# The input is the state of a Tic Tac Toe game.
# This is represented as two length-9 Vec<u8>.
# The first plane holds the location of the first player's stones,
# The second plane, the second player's.
# A 19th byte holds 0 for first player, 1 for second player.
x = tf.placeholder(tf.uint8, shape=[None, 9 * 2 + 1], name ='x')

# Training makes makes the net more likely to pick the picked move.
# The picked move will be 1.0, the other 8 spaces will be 0.0.
y_true = tf.placeholder(tf.float32, shape=[None, 9], name='y_true')

dense = tf.layers.dense(tf.cast(x, tf.float32), units=64, activation=tf.nn.relu)
logits = tf.layers.dense(dense, units=9, activation=tf.nn.relu)
softmax = tf.nn.softmax(logits, name='softmax')

sess = tf.Session()
init = tf.variables_initializer(tf.global_variables(), name='init')
sess.run(init)

loss = tf.losses.mean_squared_error(labels=y_true, predictions=softmax)
optimizer = tf.train.GradientDescentOptimizer(.01)
train = optimizer.minimize(loss, name='train')

definition = tf.Session().graph_def
tf.train.write_graph(definition, outdir, outfile, as_text=False)

加载图表:

import tensorflow as tf
import glob

num_epochs = 100
minibatch_size = 128
dataset_dir = "src/tictactoe/gamedata"
model_dir = "src/tictactoe/simple_model/checkpoint"
graph_filename = "src/tictactoe/simple_net.pb"

def make_dataset(num_epochs, minibatch_size, dataset_dir):
    files = glob.glob("{}/*.tfrecord".format(dataset_dir))
    print("loading", files)
    dataset = tf.data.TFRecordDataset(files)
    dataset = dataset.map(parse)
    dataset = dataset.shuffle(buffer_size=100000)
    dataset = dataset.batch(minibatch_size)
    print("loaded data")
    return dataset

def parse(bytes):
  features = {"game": tf.FixedLenFeature((), tf.string),
              "choice": tf.FixedLenSequenceFeature((), tf.float32, allow_missing=True)}
  parsed_features = tf.parse_single_example(bytes, features)
  game = tf.decode_raw(parsed_features["game"], tf.uint8)
  choice =  parsed_features["choice"]
  return tf.reshape(game, [19]), tf.reshape(choice, [9])


with tf.gfile.FastGFile(graph_filename,'rb') as f:
    sess = tf.InteractiveSession()

    dataset = make_dataset(num_epochs, minibatch_size, dataset_dir)
    print("loading graph at '{}'".format(graph_filename))

    iterator = dataset.make_initializable_iterator()
    example, label = iterator.get_next()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='',input_map={'x': example, 'y_true':label})

    init = tf.group(
        tf.global_variables_initializer(), 
        tf.local_variables_initializer(), 
        iterator.initializer, 
        sess.graph.get_operation_by_name('init'))

    train = sess.graph.get_operation_by_name('train')
    for name in [n.name for n in tf.get_default_graph().as_graph_def().node]:
        print(name)

    saver = tf.train.Saver()

    sess.run(init)

    for i in range(num_epochs):
        sess.run(iterator.initializer)

        while True:
            try:
                sess.run(train)
            except tf.errors.OutOfRangeError:
                break
        save_path = saver.save(sess, model_dir)
        print("Model saved in path: %s" % save_path)

Tensorflow在saver = tf.train.Saver()

投掷

我尝试确认图表已正确恢复,并且通过在保护程序行上方的默认图表中打印出所有变量,将其包含的变量加载到当前默认图表中。那里有数百个,包括我在图表创建文件(x,y_true,train等)中手工命名的那些。

相关问题似乎不是我的问题。例如,我发现的最相关的问题是: No variable to save error in Tensorflow

OP的问题是他的变量出现在错误的图表中。对我来说,只有一个图表,它肯定包含变量。

1 个答案:

答案 0 :(得分:1)

如果你想让tensorflow识别变量,你需要导入元图; graphdef本身没有足够的信息来重建一切。查看tf.train.import_meta_graph的文档。