TensorFlow Custom Estimator: defining an EstimatorSpec triggers errors

Posted: 2018-11-15 10:22:08

Tags: python tensorflow tensorflow-serving tensorflow-estimator

In a custom Estimator model_fn, I am trying to generalize some aspects. While playing around with this idea, I ran into something strange.

If I define an EstimatorSpec and do _not_ return it, the Estimator still behaves as if I had returned it. A Colab notebook with all of the code is available.

As a proof of concept, I changed only a few lines in (buggy_)model_fn and (buggy_)mode_predict (code below; see also the Colab).

Why does merely initializing an EstimatorSpec (regardless of scope) change the behavior of the Estimator?
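The question does not quote the error, so the following only sketches one documented behavior that may be involved. According to the EstimatorSpec documentation, the required fields depend on mode (TRAIN needs loss and train_op, EVAL needs loss, PREDICT needs predictions), and they are checked when the object is constructed. Because buggy_mode_predict runs for every mode, during training it builds a spec with the TRAIN mode and only predictions. The sketch is plain Python: ToySpec, make_spec and buggy_flow are hypothetical stand-ins, not TensorFlow's implementation, and the error messages are illustrative; only the mode string values are taken from tf.estimator.ModeKeys.

```python
from typing import Any, Dict, NamedTuple, Optional

# Same string values as tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


class ToySpec(NamedTuple):
    """Hypothetical stand-in for tf.estimator.EstimatorSpec (not the real class)."""
    mode: str
    predictions: Optional[Dict[str, Any]] = None
    loss: Optional[float] = None
    train_op: Optional[Any] = None


def make_spec(mode: str, predictions=None, loss=None, train_op=None) -> ToySpec:
    """Builds a ToySpec, enforcing the per-mode required fields at construction time."""
    if mode == TRAIN and train_op is None:
        raise ValueError("Missing train_op.")
    if mode in (TRAIN, EVAL) and loss is None:
        raise ValueError("Missing loss.")
    if mode == PREDICT and predictions is None:
        raise ValueError("Missing predictions.")
    return ToySpec(mode, predictions, loss, train_op)


def buggy_flow(mode: str):
    """Mimics buggy_model_fn: the predict spec is built before the mode check."""
    predictions = {"labels": [0.1, 0.9]}
    spec = make_spec(mode, predictions=predictions)  # runs for every mode
    if mode == PREDICT:
        return spec
    return "loss_fn / metrics_fn would run here"


print(buggy_flow(PREDICT).mode)  # prints: infer
try:
    buggy_flow(TRAIN)
except ValueError as err:
    print("TRAIN failed before reaching loss_fn:", err)
```

In this toy model the exception is raised on the make_spec line itself, so whether or not spec is later returned makes no difference in TRAIN and EVAL.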

model_fn

Working version

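import tensorflow as tf

# build_fn, loss_fn, metrics_fn, mode_eval and mode_train are defined in the Colab.
def model_fn(features, labels, mode, params):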
    MODEL = {'features': features, 'labels': labels, 'mode': mode, 'params': params}

    # send the features through the graph
    MODEL = build_fn(MODEL)

    # prediction
    MODEL['predictions'] = {'labels': MODEL['net_logits']}

    MODEL['export_outputs'] = {
        k: tf.estimator.export.PredictOutput(v) for k, v in MODEL['predictions'].items()
    }


    if mode == tf.estimator.ModeKeys.PREDICT: 
      return mode_predict(MODEL)

    # calculate the loss
    MODEL = loss_fn(MODEL)

    # calculate all metrics and send them to tf.summary
    MODEL = metrics_fn(MODEL)

    if mode == tf.estimator.ModeKeys.EVAL: 
      return mode_eval(MODEL)

    if mode == tf.estimator.ModeKeys.TRAIN: 
      return mode_train(MODEL)

Buggy version

def buggy_model_fn(features, labels, mode, params):
    MODEL = {'features': features, 'labels': labels, 'mode': mode, 'params': params}

    # send the features through the graph
    MODEL = build_fn(MODEL)



    # prediction
    # START BUGS HERE -----------------------------------------------
    MODEL = buggy_mode_predict(MODEL)
    if mode == tf.estimator.ModeKeys.PREDICT:
      return MODEL['PREDICT_SPEC']
    # END BUGS HERE -----------------------------------------------



    # calculate the loss
    MODEL = loss_fn(MODEL)

    # calculate all metrics and send them to tf.summary
    MODEL = metrics_fn(MODEL)

    if mode == tf.estimator.ModeKeys.EVAL: 
      return mode_eval(MODEL)

    if mode == tf.estimator.ModeKeys.TRAIN: 
      return mode_train(MODEL)
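To make the control-flow difference between model_fn and buggy_model_fn explicit, here is a TensorFlow-free trace in which each step only records its name. The two trace functions are hypothetical and assume nothing beyond the code above: in model_fn the predict spec is constructed only inside the PREDICT branch, whereas buggy_model_fn calls buggy_mode_predict, and therefore constructs an EstimatorSpec with the current mode, before the mode check.

```python
from typing import List

# Same string values as tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


def trace_model_fn(mode: str) -> List[str]:
    """Steps taken by the working model_fn for a given mode."""
    steps = ["build_fn", "predictions + export_outputs"]
    if mode == PREDICT:
        # mode_predict builds the spec only on this branch.
        steps.append(f"EstimatorSpec(mode={mode!r}, predictions=...)")
        return steps
    steps += ["loss_fn", "metrics_fn", f"mode_{mode}"]
    return steps


def trace_buggy_model_fn(mode: str) -> List[str]:
    """Steps taken by buggy_model_fn: the spec is built before the mode check."""
    steps = [
        "build_fn",
        "predictions + export_outputs",
        f"EstimatorSpec(mode={mode!r}, predictions=...)",  # built for every mode
    ]
    if mode == PREDICT:
        return steps
    steps += ["loss_fn", "metrics_fn", f"mode_{mode}"]
    return steps


for mode in (TRAIN, EVAL, PREDICT):
    print("model_fn      ", mode, trace_model_fn(mode))
    print("buggy_model_fn", mode, trace_buggy_model_fn(mode))
```

For PREDICT both traces are identical; they only diverge in TRAIN and EVAL, where the buggy version constructs a predictions-only spec that carries a training or evaluation mode.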

mode_predict

Working version

def mode_predict(model):
    """How to predict given the model.

    Args:
        model (dict): a `dict` containing the model

    Returns:
        spec (`EstimatorSpec`_): Ops and objects returned from a model_fn and passed to an Estimator

    .. _EstimatorSpec:
        https://www.tensorflow.org/api_docs/python/tf/estimator/EstimatorSpec

    """
    # do the predictions here

    spec = tf.estimator.EstimatorSpec(
        mode           = model['mode'],
        predictions    = model['predictions'],
        export_outputs = model['export_outputs']
    )
    return spec

Buggy version

def buggy_mode_predict(model):
    # do the predictions here
    model['predictions'] = {'labels': model['net_logits']}

    model['export_outputs'] = {
        k: tf.estimator.export.PredictOutput(v) for k, v in model['predictions'].items()
    }

    spec = tf.estimator.EstimatorSpec(
        mode           = model['mode'],
        predictions    = model['predictions'],
        export_outputs = model['export_outputs']
    )
    # START BUGS HERE -----------------------------------------------
    model['PREDICT_SPEC'] = spec
    # END BUGS HERE -----------------------------------------------
    return model
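If the goal is to keep generalizing through helpers like buggy_mode_predict, one option (a suggestion, not something from the question or the Colab) is to store a factory in the dict instead of an already-constructed spec, so that nothing is validated until the PREDICT branch actually calls it. The sketch uses a hypothetical make_spec stand-in that mimics only EstimatorSpec's documented per-mode required fields; fixed_mode_predict and fixed_model_fn are likewise hypothetical names. With TensorFlow the factory would presumably be functools.partial(tf.estimator.EstimatorSpec, mode=..., predictions=..., export_outputs=...), which has not been tested here.

```python
import functools
from typing import Any, Dict

# Same string values as tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


def make_spec(mode: str, predictions=None, loss=None, train_op=None) -> Dict[str, Any]:
    """Hypothetical stand-in for tf.estimator.EstimatorSpec's per-mode validation."""
    if mode == TRAIN and train_op is None:
        raise ValueError("Missing train_op.")
    if mode in (TRAIN, EVAL) and loss is None:
        raise ValueError("Missing loss.")
    if mode == PREDICT and predictions is None:
        raise ValueError("Missing predictions.")
    return {"mode": mode, "predictions": predictions, "loss": loss, "train_op": train_op}


def fixed_mode_predict(model: Dict[str, Any]) -> Dict[str, Any]:
    """Like buggy_mode_predict, but stores a spec factory instead of a spec."""
    model["predictions"] = {"labels": model["net_logits"]}
    # functools.partial only records the arguments; make_spec is not called yet.
    model["PREDICT_SPEC"] = functools.partial(
        make_spec, model["mode"], predictions=model["predictions"]
    )
    return model


def fixed_model_fn(mode: str, net_logits=(0.2, 0.8)):
    """Toy model_fn: the factory is only invoked inside the PREDICT branch."""
    model = fixed_mode_predict({"mode": mode, "net_logits": list(net_logits)})
    if mode == PREDICT:
        return model["PREDICT_SPEC"]()
    # Stand-in for loss_fn / metrics_fn / mode_train / mode_eval.
    return make_spec(mode, loss=0.0, train_op="train_op" if mode == TRAIN else None)


print(fixed_model_fn(PREDICT)["predictions"])  # {'labels': [0.2, 0.8]}
print(fixed_model_fn(TRAIN)["mode"])           # train
```

The trade-off is that callers must remember the entry is a callable; invoking it in TRAIN mode would fail exactly like the original buggy code.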

0 answers