(我现在正在使用TF 1.7,如果重要的话。)
我正在尝试初始化然后将模型和关联的元图保存在一个脚本(init.py
)中,以便我可以加载模型并从第二个脚本(train.py
)恢复训练。该模型使用占位符进行初始化,用于训练示例和标签,在训练期间将被替换为真实的张量。然而,当我尝试在train.py
(从数据集)中创建一些真正的张量时,我得到一个堆栈跟踪,表明我的迭代器尚未初始化。跟踪指向import_meta_graph()
调用,无论我使用的是 one-shot 迭代器(不应该需要初始化),还是我实际初始化过的可初始化迭代器,都会发生相同的事情。
从概念上讲,我是否错过了两个图形如何拼接在一起的东西?
我想相信这是保存和恢复元图的常见用例,但我在互联网上找不到它的任何例子。其他人如何将他们的真实数据提供给恢复的模型?
Caused by op 'IteratorGetNext_1', defined at:
File "src/tictactoe/train.py", line 47, in <module>
meta_graph, input_map={'example': example, 'label': label})
File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1927, in import_meta_graph
**kwargs)
File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/framework/meta_graph.py", line 741, in import_scoped_meta_graph
producer_op_list=producer_op_list)
File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
return func(*args, **kwargs)
File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 577, in import_graph_def
op_def=op_def)
File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3290, in create_op
op_def=op_def)
File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1654, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
[[Node: IteratorGetNext_1 = IteratorGetNext[output_shapes=[[?,19], [?,9]], output_types=[DT_UINT8, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator_1)]]
这两个脚本的完整代码:
# init.py
#
# Build a fresh TicTacToe expert model graph, initialize its variables, and
# save the checkpoint plus meta graph so train.py can restore it and resume
# training with real dataset tensors mapped onto the placeholders.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from pathlib import Path
import argparse
import os
import tensorflow as tf

parser = argparse.ArgumentParser(description='Initialize a TicTacToe expert model.')
parser.add_argument('name', metavar='foo-model', help='Model prefix')
args = parser.parse_args()

model_dir = "src/tictactoe/saved_models/" + args.name + "/" + args.name

with tf.Session() as sess:
    # Placeholders for one training example (19 raw uint8 board bytes) and its
    # 9-way move-probability label; train.py replaces both through
    # import_meta_graph(..., input_map=...).
    example = tf.placeholder(tf.uint8, shape=[1, 9 * 2 + 1], name='example')
    label = tf.placeholder(tf.float32, shape=[1, 9], name='label')

    dense = tf.layers.dense(tf.cast(example, tf.float32), units=64, activation=tf.nn.relu)
    logits = tf.layers.dense(dense, units=9, activation=tf.nn.relu)
    softmax = tf.nn.softmax(logits, name='softmax')
    tf.add_to_collection('softmax', softmax)

    loss = tf.losses.mean_squared_error(labels=label, predictions=softmax)
    optimizer = tf.train.GradientDescentOptimizer(.01)
    train = optimizer.minimize(loss, name='train')
    # BUG FIX: publish the train op under both keys -- the original stored only
    # 'train' while train.py looks up 'train_op', making the lookup fail.
    tf.add_to_collection('train', train)
    tf.add_to_collection('train_op', train)

    # BUG FIX: build and run the initializer only after the *whole* graph is
    # constructed (the original ran it before the optimizer existed, leaving
    # any optimizer-created variables uninitialized), and run it on the
    # context-managed session instead of leaking a second tf.Session().
    init = tf.group(
        tf.global_variables_initializer(),
        tf.local_variables_initializer())
    tf.add_to_collection('init', init)  # so a restorer can re-run initialization
    sess.run(init)

    saver = tf.train.Saver()
    saved = saver.save(sess, model_dir, global_step=0)
    print("Model saved in path: %s" % saved)
下面是训练脚本。
# train.py
#
# Restore the model checkpoint produced by init.py and resume training,
# feeding real tensors from a TFRecord dataset in place of the saved
# graph's placeholders.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from pathlib import Path
import argparse
import glob
import os
import tensorflow as tf

# The positional 'name' argument selects which saved model directory to train.
parser = argparse.ArgumentParser(description='Initialize a TicTacToe expert model.')
parser.add_argument('name', metavar='foo-model', help='Model prefix')
args = parser.parse_args()

model_dir = "src/tictactoe/saved_models/" + args.name
saver_prefix = "src/tictactoe/saved_models/" + args.name + "/" + args.name
# Prefix of the most recent checkpoint; its meta graph lives at "<prefix>.meta".
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
meta_graph = ".".join([latest_checkpoint, "meta"])

num_epochs = 100
minibatch_size = 128
dataset_dir = "src/tictactoe/gamedata"  # directory holding *.tfrecord game files
def make_dataset(minibatch_size, dataset_dir):
    """Build the training pipeline: parsed, shuffled, batched TFRecord data."""
    record_files = glob.glob("{}/*.tfrecord".format(dataset_dir))
    print(record_files)
    return (tf.data.TFRecordDataset(record_files)
            .map(parse)
            .shuffle(buffer_size=100000)
            .batch(minibatch_size))
def parse(serialized):
    """Decode one serialized tf.Example into (game, choice) tensors.

    Args:
        serialized: scalar string tensor holding a single tf.Example record.

    Returns:
        A (game, choice) pair: `game` is a [19] uint8 tensor of raw board
        bytes, `choice` is a [9] float32 tensor of move probabilities.
    """
    # BUG FIX: parameter renamed from `bytes`, which shadowed the builtin;
    # the only caller in this file invokes it positionally via dataset.map().
    features = {"game": tf.FixedLenFeature((), tf.string),
                "choice": tf.FixedLenSequenceFeature((), tf.float32, allow_missing=True)}
    parsed_features = tf.parse_single_example(serialized, features)
    game = tf.decode_raw(parsed_features["game"], tf.uint8)
    choice = parsed_features["choice"]
    return tf.reshape(game, [19]), tf.reshape(choice, [9])
with tf.Session() as sess:
    dataset = make_dataset(minibatch_size, dataset_dir)
    iterator = dataset.make_initializable_iterator()
    example, label = iterator.get_next()

    # Splice the dataset tensors in where the saved graph had placeholders.
    saver = tf.train.import_meta_graph(
        meta_graph, input_map={'example': example, 'label': label})
    print("{}".format(meta_graph))
    saver.restore(sess, latest_checkpoint)
    print("{}".format(latest_checkpoint))

    # BUG FIX: init.py stored the op under collection key 'train', but this
    # script asked for 'train_op' and indexed an empty list (IndexError).
    # Accept either key so both old and new checkpoints work.
    train_collection = tf.get_collection('train_op') or tf.get_collection('train')
    train_op = train_collection[0]

    for epoch in range(num_epochs):
        # (Re-)initialize the iterator at the start of every epoch; the
        # original also ran it once before the loop, which was redundant.
        sess.run(iterator.initializer)
        while True:
            try:
                sess.run(train_op)
            except tf.errors.OutOfRangeError:
                break  # this epoch's data is exhausted
        # BUG FIX: the original passed the undefined name `step` here, which
        # raised NameError; use the epoch index as the global step instead.
        print(saver.save(sess, saver_prefix, global_step=epoch))
答案 0(得分:0):
我相信我已经找到了这个问题。问题是train.py
中的我的Saver正在保存我已映射的实际输入张量。当我尝试恢复时,那些真正的输入张量从磁盘恢复,但未初始化。
所以:在运行init.py
一次后,以下train.py
脚本成功训练。但是当我再次运行时,它被映射到图中的额外输入张量被恢复但未初始化。这有点奇怪,因为我在恢复时将它们再次映射出来,所以我认为没有必要对它们进行初始化。我发现tf.report_uninitialized_variables()
对于调试问题至关重要。
# train.py (answer's working version)
#
# Restore the latest checkpoint and resume training, mapping dataset
# tensors onto the saved graph's placeholders.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from pathlib import Path
import argparse
import glob
import os
import tensorflow as tf

# The positional 'name' argument selects which saved model directory to train.
parser = argparse.ArgumentParser(description='Initialize a TicTacToe expert model.')
parser.add_argument('name', metavar='foo-model', help='Model prefix')
args = parser.parse_args()

model_dir = "src/tictactoe/saved_models/" + args.name
saver_prefix = "src/tictactoe/saved_models/" + args.name + "/" + args.name
# Prefix of the most recent checkpoint; its meta graph lives at "<prefix>.meta".
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
meta_graph = ".".join([latest_checkpoint, "meta"])

num_epochs = 100
minibatch_size = 128
dataset_dir = "src/tictactoe/gamedata"  # directory holding *.tfrecord game files
def make_dataset(minibatch_size, dataset_dir):
    """Read every *.tfrecord in dataset_dir; parse, shuffle, and batch it."""
    files = glob.glob("{}/*.tfrecord".format(dataset_dir))
    print(files)
    ds = tf.data.TFRecordDataset(files)
    ds = ds.map(parse).shuffle(buffer_size=100000)
    return ds.batch(minibatch_size)
def parse(bytes):
    """Decode one serialized tf.Example into ([19] game, [9] choice) tensors."""
    feature_spec = {
        "game": tf.FixedLenFeature((), tf.string),
        "choice": tf.FixedLenSequenceFeature((), tf.float32, allow_missing=True),
    }
    parsed = tf.parse_single_example(bytes, feature_spec)
    game = tf.decode_raw(parsed["game"], tf.uint8)
    return tf.reshape(game, [19]), tf.reshape(parsed["choice"], [9])
with tf.Session() as sess:
    dataset = make_dataset(minibatch_size, dataset_dir)
    iterator = dataset.make_initializable_iterator()
    example, label = iterator.get_next()
    # print("before iterator", sess.run(tf.report_uninitialized_variables()))
    # Map the dataset tensors onto the saved graph's 'example'/'label'
    # placeholders while importing the meta graph.
    saver = tf.train.import_meta_graph(meta_graph, input_map={'example': example, 'label': label})
    print("{}".format(meta_graph))
    saver.restore(sess, latest_checkpoint)
    print("{}".format(latest_checkpoint))
    # NOTE(review): these lookups assume the init script populated the
    # 'train_op' and 'init' collections; the init.py shown earlier in this
    # post stores the op under 'train' only -- confirm the keys match.
    train_op = tf.get_collection('train_op')[0]
    # NOTE(review): `init` is fetched but never run below -- confirm intended.
    init = tf.get_collection('init')[0]
    for i in range(num_epochs):
        sess.run(iterator.initializer)  # rewind the dataset at each epoch
        while True:
            try:
                sess.run(train_op)
            except tf.errors.OutOfRangeError:
                break  # this epoch's data is exhausted
    # Saved without global_step, so the checkpoint path stays stable and
    # latest_checkpoint keeps resolving on the next run.
    print(saver.save(sess, saver_prefix))