Memory leak in the eval-metrics host_call function of a TensorFlow custom estimator

Time: 2018-11-22 13:27:23

Tags: tensorflow memory-leaks tensorboard

I am using the code below to compute extra metrics for my training. I create a host call with host_call = (host_call_fn, metric_args) and pass it to the estimator's host_call argument. However, invoking this function causes a memory leak, and I cannot figure out where the problem is. Profiling the heap, it looks as if large dictionaries are being created somewhere but never released.

# Reshape each scalar to a rank-1 tensor of shape [1] before handing it to the host call.
p_temp = tf.reshape(policy_loss, [1], name='policy_loss_reshape')
v_temp = tf.reshape(value_loss, [1], name='value_loss_reshape')
e_temp = tf.reshape(entropy_loss, [1], name='entropy_loss_reshape')
t_temp = tf.reshape(total_loss, [1], name='total_loss_reshape')
g_temp = tf.reshape(global_step, [1], name='global_step_reshape')

metric_args = [p_temp, v_temp, e_temp, t_temp, g_temp]

# Bind the non-tensor est_mode argument up front; the tensors in metric_args
# are supplied through the host call itself.
host_call_fn = functools.partial(
    eval_metrics_host_call_fn, est_mode=tf.estimator.ModeKeys.TRAIN)
host_call = (host_call_fn, metric_args)
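
For reference, the tuple is then handed to the estimator spec roughly as below. This is only a sketch assuming a TPUEstimator setup (host_call is an argument of tf.contrib.tpu.TPUEstimatorSpec in TF 1.x); total_loss and train_op stand in for whatever the model_fn actually builds.

# Sketch only: wiring the host call into the spec returned by model_fn.
return tf.contrib.tpu.TPUEstimatorSpec(
    mode=tf.estimator.ModeKeys.TRAIN,
    loss=total_loss,
    train_op=train_op,
    host_call=host_call)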

The following function computes the extra evaluation metrics and writes them to a summary directory for TensorBoard.

# "summary" below refers to tf.contrib.summary (TF 1.x),
# i.e. from tensorflow.contrib import summary
def eval_metrics_host_call_fn(p_temp,
                              v_temp,
                              e_temp,
                              t_temp,
                              step,
                              est_mode=tf.estimator.ModeKeys.TRAIN):
  with tf.variable_scope('metrics'):
    metric_ops = {
        'policy_loss': tf.metrics.mean(p_temp, name='policy_loss_metric'),
        'value_loss': tf.metrics.mean(v_temp, name='value_loss_metric'),
        'entropy_loss': tf.metrics.mean(e_temp, name='entropy_loss_metric'),
        'total_loss': tf.metrics.mean(t_temp, name='total_loss_metric')
    }

  if est_mode == tf.estimator.ModeKeys.EVAL:
    return metric_ops

  eval_step = tf.reduce_min(step)

  # Create summary ops so that they show up in SUMMARIES collection.
  # That way, they get logged automatically during training.
  summary_writer = summary.create_file_writer(FLAGS.summary_dir)
  with summary_writer.as_default(), \
      summary.record_summaries_every_n_global_steps(FLAGS.summary_steps,
                                                    eval_step):
    for metric_name, metric_op in metric_ops.items():
      # Each metric_op is a (value, update_op) pair; logging the update op
      # updates the running mean and records its current value.
      summary.scalar(metric_name, metric_op[1], step=eval_step)

  # Reset metrics occasionally so that they are the mean of recent batches.
  reset_op = tf.variables_initializer(tf.local_variables('metrics'))
  cond_reset_op = tf.cond(
      tf.equal(eval_step % FLAGS.summary_steps, tf.to_int64(1)),
      lambda: reset_op, lambda: tf.no_op())

  return summary.all_summary_ops() + [cond_reset_op]
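
For reference, the heap inspection mentioned above was done along these lines. This is only a sketch assuming the guppy/heapy profiler; the snapshot points are illustrative and not part of the original code.

from guppy import hpy

heap_profiler = hpy()
heap_profiler.setrelheap()   # use the current heap as the baseline
# ... run several training steps that trigger the host call ...
print(heap_profiler.heap())  # dict objects keep accumulating relative to the baseline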

0 Answers:

There are no answers.