Evaluating metrics during training with a TensorFlow Estimator

Date: 2017-09-14 20:29:25

Tags: python tensorflow

I am trying to print R-squared along with the loss during training, but I cannot get it to work. Here is my implementation:

Update:

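def r_squared_fn(labels, predictions, mode):
    # No labels are available at inference time, so there is nothing to score.
    if mode == learn.ModeKeys.INFER:
        return None, None  # keep the (value, update_op) pair that model_fn unpacks
    else:
        # Running MSE of the labels against the batch label mean (proportional to SST).
        SST, update_op1 = tf.metrics.mean_squared_error(labels, tf.reduce_mean(labels))
        # Running MSE of the labels against the predictions (proportional to SSE).
        SSE, update_op2 = tf.metrics.mean_squared_error(labels, predictions)
        with tf.control_dependencies([update_op1, update_op2]):
            return tf.subtract(1.0, tf.div(SSE, SST), name='r'), tf.group(update_op1, update_op2)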

def model_fn(features, labels, mode):
    predict = prediction()
    loss = model_loss()
    r_squared,_ = r_squared_fn(labels, predict, mode)

    train_op = model_train_op(loss, mode)

    predictions = {"predictions": predict}

    training_hooks = [
        tf.train.LoggingTensorHook(
            tensors = {'training loss': loss,'training r_squared': r_squared},
            # Assumed interval: the hook expects every_n_iter or every_n_secs to be set.
            every_n_iter = 100)
    ]

    return tf.estimator.EstimatorSpec(
        mode = mode,
        predictions = predictions,
        loss = loss,
        train_op = train_op,
        training_hooks = training_hooks,
    )
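
For reference, the quantity that r_squared_fn is meant to track is R² = 1 - SSE/SST. Since both mean squared errors are divided by the same sample count, their ratio equals SSE/SST. Below is a minimal framework-free sketch for checking logged values by hand. The helper name `r_squared` and the use of plain Python lists are assumptions made for illustration, not part of the code above, and the sketch scores a single batch rather than a running average.

```python
def r_squared(labels, predictions):
    """Return R^2 = 1 - SSE / SST for two equal-length lists of numbers."""
    if not labels or len(labels) != len(predictions):
        raise ValueError("labels and predictions must be non-empty and equal in length")

    label_mean = sum(labels) / len(labels)

    # SSE: squared error of the model's predictions.
    sse = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    # SST: squared error of always predicting the label mean.
    sst = sum((y - label_mean) ** 2 for y in labels)

    if sst == 0:
        raise ValueError("R^2 is undefined when all labels are identical")

    return 1.0 - sse / sst


if __name__ == "__main__":
    print(r_squared([1.0, 2.0, 3.0, 4.0], [1.5, 2.0, 3.0, 3.5]))
```

A perfect fit gives 1.0, and predicting the label mean for every example gives 0.0, so a value logged during training can be sanity-checked against these two reference points.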

0 answers:

No answers yet