I restore a trained model as follows:

    import tensorflow as tf

    saver = tf.train.import_meta_graph('expr1.multi/train_logs/model.ckpt-44.meta')
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
    saver.restore(sess, 'expr1.multi/train_logs/model.ckpt-44')

Then I retrieve the tensors needed for inference:

    # the graph that import_meta_graph populated
    graph = tf.get_default_graph()
    logits = graph.get_tensor_by_name('strided_slice_1:0')
    logits_len = graph.get_tensor_by_name('strided_slice_2:0')
    targets = graph.get_tensor_by_name('evaluate/IteratorGetNext:2')
    targets_len = graph.get_tensor_by_name('evaluate/IteratorGetNext:3')
    # this is to retrieve the dataset.iterator.initializer operator
    init_op = graph.get_operation_by_name('evaluate/MakeIterator')

Then I run inference:

    sess.run(init_op)
    while True:
        try:
            l, ll, t, tl = sess.run([logits, logits_len, targets, targets_len])
            ...
        except tf.errors.OutOfRangeError:
            # the evaluation dataset is exhausted
            break
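
For a model trained on a single GPU, the restore and inference code above works without any problem. However, for a model trained with the following multi-GPU (asynchronous) implementation, it fails: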

    import threading

    loss_ops = []
    train_ops = []
    for gpu_i in range(self.num_gpus):
        with tf.device("/gpu:%d" % gpu_i):
            loss = ...
            # update the model parameter
            update = self.model.update(loss, global_step, self.lrate, self.grad_clip)
            loss_ops.append(loss)
            train_ops.append(update)
    # the Dataset pipeline for the evaluation data is created inside this call
    evaluator = get_evaluator(self.evaluator)(self.conf, self.model)
    mon_sess = tf.train.MonitoredTrainingSession(config=config, hooks=hooks)
    ...
    def train_helper(train_op, loss_op):
        ...
        while not mon_sess.should_stop() and has_data:
            try:
                _, lossVal, step = mon_sess.run([train_op, loss_op, global_step])
            except tf.errors.OutOfRangeError:
                has_data = False
                continue
            # validation loss evaluation every eval_steps steps
            if step != 0 and step % self.eval_steps == 0:
                # one shot iterator initialization done inside evaluate function
                val_lossVal, num_utts = evaluator.evaluate(mon_sess)
                ...
    train_threads = []
    for t_op, loss_op in zip(train_ops, loss_ops):
        train_threads.append(threading.Thread(target=train_helper, args=(t_op, loss_op)))
    # Start the threads, and block on their completion.
    for t in train_threads:
        t.start()
    for t in train_threads:
        t.join()
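
Inference on the restored model then fails with the following dataset-pipeline iterator initialization error:

    tensorflow.python.framework.errors_impl.FailedPreconditionError: GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.

I cannot figure out why.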
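
One check that might help narrow this down is to list every iterator-related op in the restored graph and compare the names against the ones looked up by hand. This is only a diagnostic sketch, not a fix: `find_ops_by_type` is a hypothetical helper (it is not part of TensorFlow), and it relies on nothing more than `Graph.get_operations()` and each operation's `name` and `type` attributes. Whether the multi-GPU graph really contains additional or differently named iterator ops is an assumption that this check would confirm or rule out.

```python
def find_ops_by_type(graph, op_types):
    """Return the names of all operations in `graph` whose type is in `op_types`.

    `graph` only needs a get_operations() method whose items expose
    .name and .type, which is what a TensorFlow 1.x tf.Graph provides.
    The names come back in the order get_operations() yields them.
    """
    wanted = set(op_types)
    return [op.name for op in graph.get_operations() if op.type in wanted]


# Hypothetical usage against the restored graph (assumes `tf` is imported
# and the checkpoint has already been restored):
#
#   graph = tf.get_default_graph()
#   print(find_ops_by_type(graph, ["MakeIterator"]))      # iterator initializers
#   print(find_ops_by_type(graph, ["IteratorGetNext"]))   # get_next() ops
```

If more than one `MakeIterator` op shows up, or if the tensors fetched for inference turn out to depend on an `IteratorGetNext` op other than `evaluate/IteratorGetNext`, running `evaluate/MakeIterator` alone would not initialize everything that `sess.run` touches.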