在自定义Estimator的model_fn
中,我试图概括一些方面。在摆弄这个想法的同时,我遇到了一些奇怪的事情。
如果我定义了一个EstimatorSpec并返回 _not _ ,它仍然会像返回的一样。包含所有代码的Colab可用。
为证明概念,我只更改了(buggy)_model_fn
和(buggy)_mode_predict
中的几行(下面的代码,还请查看Colab)。
为什么初始化EstimatorSpec(无论范围如何)都会改变Estimator的行为?
model_fn
def model_fn(features, labels, mode, params):
MODEL = {'features': features, 'labels': labels, 'mode': mode, 'params': params}
# send the features through the graph
MODEL = build_fn(MODEL)
# prediction
MODEL['predictions'] = {'labels': MODEL['net_logits']}
MODEL['export_outputs'] = {
k: tf.estimator.export.PredictOutput(v) for k, v in MODEL['predictions'].items()
}
if mode == tf.estimator.ModeKeys.PREDICT:
return mode_predict(MODEL)
# calculate the loss
MODEL = loss_fn(MODEL)
# calculate all metrics and send them to tf.summary
MODEL = metrics_fn(MODEL)
if mode == tf.estimator.ModeKeys.EVAL:
return mode_eval(MODEL)
if mode == tf.estimator.ModeKeys.TRAIN:
return mode_train(MODEL)
def buggy_model_fn(features, labels, mode, params):
MODEL = {'features': features, 'labels': labels, 'mode': mode, 'params': params}
# send the features through the graph
MODEL = build_fn(MODEL)
# prediction
# START BUGS HERE -----------------------------------------------
MODEL = buggy_mode_predict(MODEL)
if mode == tf.estimator.ModeKeys.PREDICT:
return MODEL['PREDICT_SPEC']
# END BUGS HERE -----------------------------------------------
# calculate the loss
MODEL = loss_fn(MODEL)
# calculate all metrics and send them to tf.summary
MODEL = metrics_fn(MODEL)
if mode == tf.estimator.ModeKeys.EVAL:
return mode_eval(MODEL)
if mode == tf.estimator.ModeKeys.TRAIN:
return mode_train(MODEL)
def mode_predict(model):
"""How to predict given the model.
Args:
model (dict): a `dict` containing the model
Returns:
spec (`EstimatorSpec`_): Ops and objects returned from a model_fn and passed to an Estimator
.. _EstimatorSpec:
https://www.tensorflow.org/api_docs/python/tf/estimator/EstimatorSpec
"""
# do the predictions here
spec = tf.estimator.EstimatorSpec(
mode = model['mode'],
predictions = model['predictions'],
export_outputs = model['export_outputs']
)
return spec
def buggy_mode_predict(model):
# do the predictions here
model['predictions'] = {'labels': model['net_logits']}
model['export_outputs'] = {
k: tf.estimator.export.PredictOutput(v) for k, v in model['predictions'].items()
}
spec = tf.estimator.EstimatorSpec(
mode = model['mode'],
predictions = model['predictions'],
export_outputs = model['export_outputs']
)
# START BUGS HERE -----------------------------------------------
model['PREDICT_SPEC'] = spec
# END BUGS HERE -----------------------------------------------
return model