我正在尝试保存一个模型,它的计算图(graph)是从磁盘加载的。我可以顺利地加载并检查这个图,也能运行训练操作,但只要创建 Saver(tf.train.Saver)就会得到 ValueError: No variables to save(没有要保存的变量)。
计算图定义:
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from pathlib import Path
import os
import tensorflow as tf
# Build the Tic Tac Toe network and write its GraphDef, in binary form, to a
# ".pb" file named after this script and placed in this script's directory.
outdir = os.path.dirname(__file__)
outfile = Path(__file__).stem + ".pb"
print(os.path.join(outdir, outfile))

# The input is the state of a Tic Tac Toe game.
# This is represented as two length-9 Vec<u8>.
# The first plane holds the location of the first player's stones,
# the second plane, the second player's.
# A 19th byte holds 0 for first player, 1 for second player.
x = tf.placeholder(tf.uint8, shape=[None, 9 * 2 + 1], name='x')

# Training makes the net more likely to pick the picked move.
# The picked move will be 1.0, the other 8 spaces will be 0.0.
y_true = tf.placeholder(tf.float32, shape=[None, 9], name='y_true')

dense = tf.layers.dense(tf.cast(x, tf.float32), units=64, activation=tf.nn.relu)
logits = tf.layers.dense(dense, units=9, activation=tf.nn.relu)
softmax = tf.nn.softmax(logits, name='softmax')

# Explicitly named so the op can be looked up by name in the exported graph.
init = tf.variables_initializer(tf.global_variables(), name='init')

loss = tf.losses.mean_squared_error(labels=y_true, predictions=softmax)
optimizer = tf.train.GradientDescentOptimizer(.01)
train = optimizer.minimize(loss, name='train')

# Use a single session and close it deterministically. The original created
# two sessions (one only to read `graph_def`) and never closed either.
# Running `init` is not needed for the export itself; it is kept so the script
# still fails here if the initializer cannot be executed.
with tf.Session() as sess:
    sess.run(init)
    definition = sess.graph_def

tf.train.write_graph(definition, outdir, outfile, as_text=False)
加载计算图:
import tensorflow as tf
import glob
# Number of passes the training loop makes over the dataset.
num_epochs = 100
# Batch size applied to the dataset in make_dataset().
minibatch_size = 128
# Directory that make_dataset() globs for "*.tfrecord" files.
dataset_dir = "src/tictactoe/gamedata"
# Path passed to saver.save() when the model is checkpointed.
model_dir = "src/tictactoe/simple_model/checkpoint"
# Binary file that is read and parsed into a tf.GraphDef.
graph_filename = "src/tictactoe/simple_net.pb"
def make_dataset(num_epochs, minibatch_size, dataset_dir):
    """Build a shuffled, batched dataset from the TFRecord files in a directory.

    Args:
        num_epochs: Accepted for interface compatibility; not used here.
        minibatch_size: Number of examples placed in each batch.
        dataset_dir: Directory that is globbed for ``*.tfrecord`` files.

    Returns:
        The dataset after each record has been mapped through ``parse``,
        shuffled with a 100000-element buffer, and batched.
    """
    record_paths = glob.glob("{}/*.tfrecord".format(dataset_dir))
    print("loading", record_paths)
    batched = (
        tf.data.TFRecordDataset(record_paths)
        .map(parse)
        .shuffle(buffer_size=100000)
        .batch(minibatch_size)
    )
    print("loaded data")
    return batched
def parse(bytes):
    """Decode one serialized example into a ``(game, choice)`` tensor pair.

    Args:
        bytes: Serialized example handed to ``tf.parse_single_example``.
            (The name shadows the builtin; it is kept so the signature
            stays unchanged.)

    Returns:
        A tuple of the "game" feature decoded as uint8 and reshaped to
        ``[19]``, and the float32 "choice" feature reshaped to ``[9]``.
    """
    schema = {
        "game": tf.FixedLenFeature((), tf.string),
        "choice": tf.FixedLenSequenceFeature((), tf.float32, allow_missing=True),
    }
    record = tf.parse_single_example(bytes, schema)
    board = tf.reshape(tf.decode_raw(record["game"], tf.uint8), [19])
    move = tf.reshape(record["choice"], [9])
    return board, move
# Training script body: parse the serialized GraphDef, feed the dataset
# iterator into the imported graph in place of 'x' and 'y_true', run the
# 'train' op over the data for num_epochs passes, then checkpoint.
# NOTE(review): indentation was lost in this paste. The nesting below
# (everything inside the `with`, checkpoint written once after all epochs)
# is an assumption -- confirm against the real file.
with tf.gfile.FastGFile(graph_filename,'rb') as f:
    sess = tf.InteractiveSession()
    dataset = make_dataset(num_epochs, minibatch_size, dataset_dir)
    print("loading graph at '{}'".format(graph_filename))
    iterator = dataset.make_initializable_iterator()
    example, label = iterator.get_next()
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    # Import with an empty name prefix, mapping 'x' and 'y_true' to the
    # iterator outputs.
    tf.import_graph_def(graph_def, name='',input_map={'x': example, 'y_true':label})
    # One op that runs the global/local initializers, the iterator
    # initializer, and the imported graph's own 'init' op.
    init = tf.group(
        tf.global_variables_initializer(),
        tf.local_variables_initializer(),
        iterator.initializer,
        sess.graph.get_operation_by_name('init'))
    train = sess.graph.get_operation_by_name('train')
    # Debug output: print the name of every node in the default graph.
    for name in [n.name for n in tf.get_default_graph().as_graph_def().node]:
        print(name)
    # NOTE(review): this is the line reported to raise
    # "ValueError: No variables to save". Presumably importing a bare GraphDef
    # does not register the variables the Saver looks for; consider exporting
    # and importing a MetaGraph (tf.train.import_meta_graph) -- confirm.
    saver = tf.train.Saver()
    sess.run(init)
    for i in range(num_epochs):
        # Rewind the iterator at the start of every epoch.
        sess.run(iterator.initializer)
        while True:
            try:
                sess.run(train)
            except tf.errors.OutOfRangeError:
                # Dataset exhausted: this epoch is finished.
                break
    save_path = saver.save(sess, model_dir)
    print("Model saved in path: %s" % save_path)
TensorFlow 在 saver = tf.train.Saver() 这一行抛出该错误。
我在创建 Saver 的那一行之前打印出了默认图中的所有变量,以确认图已被正确恢复,并且其中包含的变量已加载到当前默认图中。那里有数百个,包括我在图定义文件中手工命名的那些(x、y_true、train 等)。
相关的问题似乎都与我的情况不同。例如,我找到的最相关的问题是:No variable to save error in Tensorflow
那位提问者的问题在于他的变量位于错误的图中。而我这里只有一个图,并且它确实包含这些变量。
答案 0 :(得分:1)
如果想让 TensorFlow 识别这些变量,你需要导入元图(MetaGraph);GraphDef 本身没有足够的信息来重建全部内容。请查看 tf.train.import_meta_graph 的文档。