Tensorflow:在恢复的Metagraph中用真实张量替换占位符

时间:2018-06-17 05:09:32

标签: tensorflow tensorflow-datasets

(我现在正在使用TF 1.7,如果重要的话。)

我正在尝试初始化然后将模型和关联的元图保存在一个脚本(init.py)中,以便我可以加载模型并从第二个脚本(train.py)恢复训练。该模型使用占位符进行初始化,用于训练示例和标签,在训练期间将替换为真实的张量。然而,当我尝试在train.py(从数据集)中创建一些真正的张量时,我得到一个堆栈跟踪,表明我的迭代器尚未初始化。跟踪指向import_meta_graph()调用,无论我使用的是one-shot迭代器(不应该要求初始化),还是我实际初始化过的可初始化迭代器,都会发生相同的事情。

从概念上讲,我是否错过了两个图形如何拼接在一起的东西?

我想相信这是保存和恢复元图的常见用例,但我在互联网上找不到它的任何例子。其他人如何将他们的真实数据提供给恢复的模型?

Caused by op 'IteratorGetNext_1', defined at:
  File "src/tictactoe/train.py", line 47, in <module>
    meta_graph, input_map={'example': example, 'label': label})
  File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/training/saver.py", line 1927, in import_meta_graph
    **kwargs)
  File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/framework/meta_graph.py", line 741, in import_scoped_meta_graph
    producer_op_list=producer_op_list)
  File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 432, in new_func
    return func(*args, **kwargs)
  File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/framework/importer.py", line 577, in import_graph_def
    op_def=op_def)
  File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3290, in create_op
    op_def=op_def)
  File "/home/mason/dev/rust/seraphim/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1654, in __init__
    self._traceback = self._graph._extract_stack()  # pylint: disable=protected-access

FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
         [[Node: IteratorGetNext_1 = IteratorGetNext[output_shapes=[[?,19], [?,9]], output_types=[DT_UINT8, DT_FLOAT], _device="/job:localhost/replica:0/task:0/device:CPU:0"](Iterator_1)]]

这两个脚本的完整代码:

# init.py

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pathlib import Path

import argparse
import os
import tensorflow as tf

parser = argparse.ArgumentParser(description='Initialize a TicTacToe expert model.')
parser.add_argument('name', metavar='foo-model', help='Model prefix')
args = parser.parse_args()
# Checkpoints are written to saved_models/<name>/ with file prefix <name>.
model_dir = "src/tictactoe/saved_models/" + args.name + "/" + args.name

with tf.Session() as sess:
    # Named placeholders so train.py can splice real tensors in via input_map.
    example = tf.placeholder(tf.uint8, shape=[1, 9 * 2 + 1], name='example')
    label = tf.placeholder(tf.float32, shape=[1, 9], name='label')
    dense = tf.layers.dense(tf.cast(example, tf.float32), units=64, activation=tf.nn.relu)
    logits = tf.layers.dense(dense, units=9, activation=tf.nn.relu)
    softmax = tf.nn.softmax(logits, name='softmax')
    tf.add_to_collection('softmax', softmax)

    loss = tf.losses.mean_squared_error(labels=label, predictions=softmax)
    optimizer = tf.train.GradientDescentOptimizer(.01)
    train = optimizer.minimize(loss, name='train')
    tf.add_to_collection('train', train)
    # Also register under 'train_op' and expose the init op: train.py looks
    # these up by those keys after restoring the metagraph.
    tf.add_to_collection('train_op', train)

    # BUG FIX: the original opened a *second* tf.Session() inside this `with`
    # block (shadowing `sess` and leaking the outer session) and ran the
    # initializer before the train/loss ops existed.  Build the whole graph
    # first, then initialize once in the managed session.
    init = tf.group(
        tf.global_variables_initializer(),
        tf.local_variables_initializer())
    tf.add_to_collection('init', init)
    sess.run(init)

    saver = tf.train.Saver()
    saved = saver.save(sess, model_dir, global_step=0)
    print("Model saved in path: %s" % saved)

这是训练脚本。

# train.py
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pathlib import Path
import argparse
import glob
import os
import tensorflow as tf

parser = argparse.ArgumentParser(description='Initialize a TicTacToe expert model.')
parser.add_argument('name', metavar='foo-model', help='Model prefix')
args = parser.parse_args()

# Checkpoints live under saved_models/<name>/; individual files share the
# <name> prefix (same layout init.py writes to).
model_dir = "src/tictactoe/saved_models/" + args.name 
saver_prefix = "src/tictactoe/saved_models/" + args.name + "/" + args.name

# Locate the newest checkpoint written by init.py and derive the path of the
# corresponding .meta graph file (e.g. "<prefix>-0.meta").
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
meta_graph = ".".join([latest_checkpoint, "meta"])

# Training hyperparameters and location of the .tfrecord training data.
num_epochs = 100
minibatch_size = 128
dataset_dir = "src/tictactoe/gamedata"
def make_dataset(minibatch_size, dataset_dir):
    """Build the training pipeline: every .tfrecord in dataset_dir, parsed,
    shuffled, and batched into minibatches of `minibatch_size`."""
    tfrecord_paths = glob.glob("{}/*.tfrecord".format(dataset_dir))
    print(tfrecord_paths)
    return (tf.data.TFRecordDataset(tfrecord_paths)
            .map(parse)
            .shuffle(buffer_size=100000)
            .batch(minibatch_size))

def parse(serialized):
  """Decode one serialized tf.Example into a (game, choice) tensor pair.

  Args:
    serialized: scalar string tensor holding one serialized tf.Example.

  Returns:
    Tuple of (game, choice): uint8 [19] board encoding and float32 [9] label.
  """
  # IDIOM FIX: the parameter was named `bytes`, shadowing the builtin; the
  # function is only ever called positionally via dataset.map(parse), so the
  # rename is safe for callers.
  features = {"game": tf.FixedLenFeature((), tf.string),
              "choice": tf.FixedLenSequenceFeature((), tf.float32, allow_missing=True)}
  parsed_features = tf.parse_single_example(serialized, features)
  game = tf.decode_raw(parsed_features["game"], tf.uint8)
  choice = parsed_features["choice"]
  return tf.reshape(game, [19]), tf.reshape(choice, [9])

with tf.Session() as sess:
    # Build the input pipeline first so its tensors exist before the
    # metagraph import maps the saved placeholders onto them.
    dataset = make_dataset(minibatch_size, dataset_dir)
    iterator = dataset.make_initializable_iterator()
    example, label = iterator.get_next()

    # Splice the live dataset tensors in place of the saved 'example'/'label'
    # placeholders from init.py.
    saver = tf.train.import_meta_graph(
        meta_graph, input_map={'example': example, 'label': label})
    print("{}".format(meta_graph))
    saver.restore(sess, latest_checkpoint)
    print("{}".format(latest_checkpoint))

    # BUG FIX: init.py registers the op under collection key 'train', not
    # 'train_op', so the original lookup returned an empty list and raised
    # IndexError.
    train_op = tf.get_collection('train')[0]

    for epoch in range(num_epochs):
        # Rewind the dataset at the start of each epoch.  (The original also
        # ran the initializer once before restore — redundant with this.)
        sess.run(iterator.initializer)
        while True:
            try:
                sess.run(train_op)
            except tf.errors.OutOfRangeError:
                break
        # BUG FIX: the original called saver.save(..., global_step=step) with
        # `step` never defined (NameError), and did so after every minibatch.
        # Checkpoint once per epoch, numbered by epoch.
        print(saver.save(sess, saver_prefix, global_step=epoch))

1 个答案:

答案 0 :(得分:0)

我相信我已经找到了这个问题。问题是train.py中的我的Saver正在保存我已映射的实际输入张量。当我尝试恢复时,那些真正的输入张量从磁盘恢复,但未初始化。

所以:在运行init.py一次后,以下train.py脚本可以成功训练。但是当我再次运行它时,那些被映射进图中的额外输入张量虽然被恢复了,却未被初始化。这有点奇怪,因为我在恢复时又把它们映射掉了,所以我认为没有必要对它们进行初始化。我发现tf.report_uninitialized_variables()对于调试这个问题至关重要。

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from pathlib import Path
import argparse
import glob
import os
import tensorflow as tf

parser = argparse.ArgumentParser(description='Initialize a TicTacToe expert model.')
parser.add_argument('name', metavar='foo-model', help='Model prefix')
args = parser.parse_args()

# Directory layout matches init.py: checkpoints under saved_models/<name>/
# with file prefix <name>.
model_dir = "src/tictactoe/saved_models/" + args.name 
saver_prefix = "src/tictactoe/saved_models/" + args.name + "/" + args.name

# Find the most recent checkpoint and the path of its .meta graph file.
latest_checkpoint = tf.train.latest_checkpoint(model_dir)
meta_graph = ".".join([latest_checkpoint, "meta"])

# Training hyperparameters and the .tfrecord data directory.
num_epochs = 100
minibatch_size = 128
dataset_dir = "src/tictactoe/gamedata"
def make_dataset(minibatch_size, dataset_dir):
    """Assemble the input pipeline: read all .tfrecord files in dataset_dir,
    parse each record, shuffle, and batch."""
    record_paths = glob.glob("{}/*.tfrecord".format(dataset_dir))
    print(record_paths)
    pipeline = tf.data.TFRecordDataset(record_paths)
    pipeline = pipeline.map(parse)
    pipeline = pipeline.shuffle(buffer_size=100000)
    return pipeline.batch(minibatch_size)

def parse(serialized):
  """Decode one serialized tf.Example into a (game, choice) tensor pair.

  Args:
    serialized: scalar string tensor holding one serialized tf.Example.

  Returns:
    Tuple of (game, choice): uint8 [19] board encoding and float32 [9] label.
  """
  # BUG FIX: as pasted, the function body had been dedented out of the def
  # (an indentation mangle that is a syntax error); restored here.  The
  # parameter is also renamed from `bytes`, which shadowed the builtin —
  # safe because the function is only called positionally via dataset.map.
  features = {"game": tf.FixedLenFeature((), tf.string),
              "choice": tf.FixedLenSequenceFeature((), tf.float32, allow_missing=True)}
  parsed_features = tf.parse_single_example(serialized, features)
  game = tf.decode_raw(parsed_features["game"], tf.uint8)
  choice = parsed_features["choice"]
  return tf.reshape(game, [19]), tf.reshape(choice, [9])

with tf.Session() as sess:
    # Build the input pipeline first so its tensors exist before the
    # metagraph import maps the saved placeholders onto them.
    dataset = make_dataset(minibatch_size, dataset_dir)
    iterator = dataset.make_initializable_iterator()
    example, label = iterator.get_next()    
    # print("before iterator", sess.run(tf.report_uninitialized_variables()))

    # Replace the saved 'example'/'label' placeholders with the live dataset
    # tensors while importing the graph.
    saver = tf.train.import_meta_graph(meta_graph, input_map={'example': example, 'label': label})
    print("{}".format(meta_graph))
    saver.restore(sess, latest_checkpoint)
    print("{}".format(latest_checkpoint))

    # NOTE(review): assumes a revised init.py that stores ops under the
    # 'train_op' and 'init' collection keys; the init.py shown earlier in
    # this post uses the key 'train' and defines no 'init' collection —
    # confirm before running.  `init` is fetched but never run below.
    train_op = tf.get_collection('train_op')[0]
    init = tf.get_collection('init')[0]

    for i in range(num_epochs):
        # Rewind the dataset at the start of each epoch.
        sess.run(iterator.initializer)
        while True:
            try:
                sess.run(train_op)
            except tf.errors.OutOfRangeError:
                break
            # NOTE(review): this checkpoints after every minibatch, not once
            # per epoch — likely slower than intended; verify.
            print(saver.save(sess, saver_prefix))