How do I add an image to the summary during evaluation when using Estimator?

Time: 2018-09-18 14:19:57

Tags: python-3.x tensorflow

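I run an evaluation at the end of each epoch and need to display an image computed from the features and labels arguments of the model function model_fn. Adding tf.summary.image(name, image) to the evaluation branch of the model function does not help, and it looks to me as if the only way is to pass the right eval_metric_ops when constructing the EstimatorSpec for mode EVAL. So I first subclassed Estimator so that it handles images. The following code is mostly taken from estimator.py; the only changes are the few lines marked "my change" inside _write_dict_to_summary: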

import logging
import io
import numpy as np
import matplotlib.pyplot as plt
import six
from google.protobuf import message
import tensorflow as tf
from tensorflow.python.training import evaluation
from tensorflow.python.framework import ops  # GraphKeys is defined here, as in estimator.py
from tensorflow.python.estimator.estimator import _dict_to_str, _write_checkpoint_path_to_summary
from tensorflow.core.framework import summary_pb2
from tensorflow.python.framework import tensor_util
from tensorflow.python.summary.writer import writer_cache


def dump_as_image(a):
    vmin = np.min(a)
    vmax = np.max(a)
    img = np.squeeze((a - vmin) / (vmax - vmin) * 255).astype(np.uint8)
    s = io.BytesIO()
    plt.imsave(s, img, format='png', vmin=0, vmax=255, cmap='gray')
    return s.getvalue()


# see https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/estimator/estimator.py

def _write_dict_to_summary(output_dir, dictionary, current_global_step):
    logging.info('Saving dict for global step %d: %s', current_global_step, _dict_to_str(dictionary))
    summary_writer = writer_cache.FileWriterCache.get(output_dir)
    summary_proto = summary_pb2.Summary()
    for key in dictionary:
        if dictionary[key] is None:
            continue
        if key == 'global_step':
            continue
        if (isinstance(dictionary[key], np.float32) or
            isinstance(dictionary[key], float)):
            summary_proto.value.add(tag=key, simple_value=float(dictionary[key]))
        elif (isinstance(dictionary[key], np.int64) or
              isinstance(dictionary[key], np.int32) or
              isinstance(dictionary[key], int)):
            summary_proto.value.add(tag=key, simple_value=int(dictionary[key]))
        elif isinstance(dictionary[key], six.binary_type):
            try:
                summ = summary_pb2.Summary.FromString(dictionary[key])
                for i, img_bytes in enumerate(summ.value):
                    summ.value[i].tag = '%s/%d' % (key, i)
                summary_proto.value.extend(summ.value)
            except message.DecodeError:
                logging.warning('Skipping summary for %s, cannot parse string to Summary.', key)
                continue
        elif isinstance(dictionary[key], np.ndarray):
            value = summary_proto.value.add()
            value.tag = key
            value.node_name = key
            array = dictionary[key]

            # my change begins
            if array.ndim == 2:
                buffer = dump_as_image(array)
                value.image.encoded_image_string = buffer
            # my change ends

            else:
                tensor_proto = tensor_util.make_tensor_proto(array)
                value.tensor.CopyFrom(tensor_proto)

                logging.info(
                    'Summary for np.ndarray is not visible in Tensorboard by default. '
                    'Consider using a Tensorboard plugin for visualization (see '
                    'https://github.com/tensorflow/tensorboard-plugin-example/blob/master/README.md'
                    ' for more information).')
        else:
            logging.warning(
                'Skipping summary for %s, must be a float, np.float32, np.int64, '
                'np.int32 or int or np.ndarray or a serialized string of Summary.',
                key)
    summary_writer.add_summary(summary_proto, current_global_step)
    summary_writer.flush()


class ImageMonitoringEstimator(tf.estimator.Estimator):

    def __init__(self, *args, **kwargs):
        tf.estimator.Estimator._assert_members_are_not_overridden = lambda self: None
        super(ImageMonitoringEstimator, self).__init__(*args, **kwargs)

    def _evaluate_run(self, checkpoint_path, scaffold, update_op, eval_dict, all_hooks, output_dir):

        eval_results = evaluation._evaluate_once(
            checkpoint_path=checkpoint_path,
            master=self._config.evaluation_master,
            scaffold=scaffold,
            eval_ops=update_op,
            final_ops=eval_dict,
            hooks=all_hooks,
            config=self._session_config)

        current_global_step = eval_results[ops.GraphKeys.GLOBAL_STEP]

        _write_dict_to_summary(
            output_dir=output_dir,
            dictionary=eval_results,
            current_global_step=current_global_step)

        if checkpoint_path:
            _write_checkpoint_path_to_summary(
                output_dir=output_dir,
                checkpoint_path=checkpoint_path,
                current_global_step=current_global_step)

        return eval_results
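
A side note on dump_as_image in the listing above: the min-max scaling divides by (vmax - vmin), which is zero for a constant array. The following is a minimal sketch of a guarded version, assuming NumPy only and leaving the PNG encoding step out. The helper name to_uint8_gray is hypothetical and not part of TensorFlow or the original code.

```python
import numpy as np


def to_uint8_gray(a):
    """Scale an array to the 0..255 uint8 range (hypothetical helper)."""
    a = np.squeeze(np.asarray(a, dtype=np.float64))
    vmin, vmax = a.min(), a.max()
    if vmax == vmin:
        # Constant input: avoid 0/0 and return a black image of the same shape.
        return np.zeros(a.shape, dtype=np.uint8)
    # astype truncates toward zero, matching the behaviour of the original code.
    return ((a - vmin) / (vmax - vmin) * 255).astype(np.uint8)


ramp = to_uint8_gray([[0.0, 0.5], [1.0, 2.0]])
flat = to_uint8_gray(np.ones((2, 2)))
```

Returning a black image for constant input is only one possible choice; mapping it to mid-gray would work just as well for monitoring purposes.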

The model function looks like this:

def model_func(features, labels, mode):
    # calculate network_output
    if mode == tf.estimator.ModeKeys.TRAIN:
        # training
        ...
    elif mode == tf.estimator.ModeKeys.EVAL:
        # make_image consists of slicing and concatenations
        images = tf.map_fn(make_image, (features, network_output, labels), dtype=features.dtype)
        eval_metric_ops = images, tf.no_op()  # not working
        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops={'images': eval_metric_ops})
    else:
        # prediction
        ...

And the main part:

# mon_features and mon_labels are np.ndarray
estimator = ImageMonitoringEstimator(model_fn=model_func,...)
mon_input_func = tf.estimator.inputs.numpy_input_fn(mon_features, mon_labels, shuffle=False, num_epochs=num_epochs, batch_size=len(mon_features))
for _ in range(num_epochs):
    estimator.train(...)
    estimator.evaluate(input_fn=mon_input_func)

The code above produces a warning (and later an error):

WARNING:tensorflow:An OutOfRangeError or StopIteration exception is raised by the code in FinalOpsHook. This typically means the Ops running by the FinalOpsHook have a dependency back to some input source, which should not happen. For example, for metrics in tf.estimator.Estimator, all metrics functions return two Ops: `value_op` and `update_op`. Estimator.evaluate calls the `update_op` for each batch of the data in input source and, once it is exhausted, it call the `value_op` to get the metric values. The `value_op` here should have dependency back to variables reading only, rather than reading another batch from input. Otherwise, the `value_op`, executed by `FinalOpsHook`, triggers another data reading, which ends OutOfRangeError/StopIteration. Please fix that.

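It looks like I did not set eval_metric_ops correctly. I guess tf.map_fn touches another batch, as the warning message hints; maybe I need some stacking operation as the update_op to build the images for monitoring incrementally? But I am not sure how to do that. So how do I add images to the summary during evaluation when using Estimator?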
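
The value_op / update_op contract described in the warning can be illustrated without TensorFlow. The following is a minimal sketch under simplifying assumptions: BatchSource and evaluate are hypothetical stand-ins for the input pipeline and the evaluation loop, not TensorFlow APIs, and the real implementation differs in detail. It only shows why a value op that reads from the input fails once the data is exhausted.

```python
class BatchSource:
    """Stand-in for an input pipeline; raises StopIteration when exhausted."""

    def __init__(self, batches):
        self._it = iter(batches)

    def next_batch(self):
        return next(self._it)


def evaluate(source, update_op, value_op):
    """Mimic the loop: run update_op per batch, then run value_op once."""
    while True:
        try:
            batch = source.next_batch()
        except StopIteration:
            break  # input exhausted, stop updating
        update_op(batch)
    # value_op runs after the input is drained, so it must not read input.
    return value_op()


# Well-behaved metric: value_op reads only state stored by update_op.
state = {"last": None}
source = BatchSource([[1, 2], [3, 4]])
good = evaluate(source, lambda b: state.update(last=b), lambda: state["last"])

# Ill-behaved metric: value_op pulls a fresh batch from the drained source.
source2 = BatchSource([[1, 2], [3, 4]])
try:
    evaluate(source2, lambda b: None, source2.next_batch)
    bad = "no error"
except StopIteration:
    bad = "StopIteration"
```

In the model function from the question, images is computed from features, so using it as the value op appears to correspond to the second case: it asks the input for another batch after the last one has been consumed.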

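1 answer:

Answer 0 (score: 1)

I got it to work by creating a tf.train.SummarySaverHook in evaluation mode and passing it to tf.estimator.EstimatorSpec via evaluation_hooks=. Here images is the list of tf.summary.image ops you want written during evaluation. For example: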

eval_summary_hook = tf.train.SummarySaverHook(output_dir=params['eval_save_path'], summary_op=images, save_secs=120)

spec = tf.estimator.EstimatorSpec(mode=mode, predictions=y_pred, loss=loss, eval_metric_ops=eval_metric_ops,
                                      evaluation_hooks=[eval_summary_hook])