了解张量流的估计器类

时间:2018-08-12 21:12:46

标签: python-3.x tensorflow

我想了解为什么使用tf.estimator.EstimatorSpec()的原因和地点。我阅读了Tensorflow网站上的文档,但对此一无所知。

请用简单的语言向我解释。

1 个答案:

答案 0 :(得分:0)

我第一次阅读API时有点I,因此我写了this repo和基本的explanation

简而言之:tf.estimator.Estimator需要一个model_fn作为输入参数。 model_fn应该是映射(features, labels, mode, [config, params]) -> tf.estimator.EstimatorSpec function 。 (config和params参数是可选的。)

EstimatorSpec本身是一个估计量的规范,它包含除输入数据本身(train / evaluate中提供的所有训练,评估和预测所需的一切) / predict类的/ tf.estimator.Estimator方法。

Except来自上述存储库:

def get_logits(image):
    """Get logits from image."""
    x = image
    for filters in (32, 64):
        x = tf.layers.conv2d(x, filters, 3)
        x = tf.nn.relu(x)
        x = tf.layers.max_pooling2d(x, 3, 2)
    x = tf.reduce_mean(x, axis=(1, 2))
    logits = tf.layers.dense(x, 10)
    return logits


def get_estimator_spec(features, labels, mode):
    """
    Get an estimator specification.
    Args:
      features: mnist image batch, flaot32 tensor of shape
          (batch_size, 28, 28, 1)
      labels: mnist label batch, int32 tensor of shape (batch_size,)
      mode: one of `tf.estimator.ModeKeys`, i.e. {"train", "infer", "predict"}
    Returns:
      tf.estimator.EstimatorSpec
    """
    if mode not in {"train", "infer", "eval"}:
        raise ValueError('mode should be in {"train", "infer", "eval"}')

    logits = get_logits(features)
    preds = tf.argmax(logits, axis=-1)
    probs = tf.nn.softmax(logits, axis=-1)
    predictions = dict(preds=preds, probs=probs, image=features)

    if mode == 'infer':
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)

    optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
    step = tf.train.get_or_create_global_step()
    train_op = optimizer.minimize(loss, global_step=step)

    accuracy = tf.metrics.accuracy(labels, preds)

    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions,
        loss=loss, train_op=train_op, eval_metric_ops=dict(accuracy=accuracy))


model_dir = '/tmp/mnist_simple'


def get_estimator():
    return tf.estimator.Estimator(get_estimator_spec, model_dir)