I slightly modified nmt here (https://github.com/tensorflow/nmt/blob/tf-1.4/nmt/model.py#L227):
partial_setup = sess.partial_run_setup(
    [self.train_logits, self.train_op_dummy_reward, self.train_loss,
     self.predict_count, self.train_summary, self.global_step,
     self.word_count, self.batch_size],
    [self.rewards])
scores = sess.partial_run(partial_setup, self.train_logits)
rewards_output = self.calculate_rewards(scores)
step_loss = sess.partial_run(partial_setup, self.train_loss, feed_dict={self.rewards: rewards_output})
print('Loss value:', step_loss)
r = sess.partial_run(partial_setup, self.train_op_dummy_reward)
step_predict_count, step_summary, global_step, step_word_count, batch_size = sess.partial_run(
    partial_setup,
    [self.predict_count, self.train_summary, self.global_step,
     self.word_count, self.batch_size])
return None, step_loss, step_predict_count, step_summary, global_step, step_word_count, batch_size
It runs for one epoch and then fails with this error:
tensorflow.python.framework.errors_impl.CancelledError: Run call was cancelled
This happens when the following is called (https://github.com/tensorflow/nmt/blob/tf-1.4/nmt/train.py#L260):
train_sess.run(
    train_model.iterator.initializer,
    feed_dict={train_model.skip_count_placeholder: 0})