我想了解为什么以及在什么场合下需要使用 tf.estimator.EstimatorSpec()。
我阅读了 TensorFlow 网站上的文档,但仍然没有弄明白。
请用简单的语言向我解释。
答案 0 :(得分:0)
我第一次阅读这个 API 时也有点困惑,因此我写了一个代码仓库(this repo)和一份基础说明(explanation)。
简而言之:tf.estimator.Estimator 需要一个 model_fn 作为输入参数。
model_fn 应该是一个函数,它实现如下映射:
(features, labels, mode, [config, params]) -> tf.estimator.EstimatorSpec
(其中 config 和 params 参数是可选的。)
EstimatorSpec 本身是对一个估计器(estimator)的规范,
它包含训练、评估和预测所需的一切,但不包括输入数据本身
(输入数据由 tf.estimator.Estimator 类的 train / evaluate / predict 方法提供)。
以下代码摘自(excerpt)上述代码仓库:
def get_logits(image):
    """
    Compute classification logits for a batch of images.

    The input goes through two conv -> relu -> max-pool stages (32 filters,
    then 64), is averaged over axes 1 and 2, and is finally projected to
    10 output units by a dense layer.

    Args:
        image: input tensor fed to the first convolution.

    Returns:
        The output tensor of the final 10-unit dense layer.
    """
    net = image
    for num_filters in (32, 64):
        net = tf.layers.conv2d(net, filters=num_filters, kernel_size=3)
        net = tf.nn.relu(net)
        net = tf.layers.max_pooling2d(net, pool_size=3, strides=2)
    # Collapse axes 1 and 2 into a single mean per remaining channel.
    pooled = tf.reduce_mean(net, axis=(1, 2))
    return tf.layers.dense(pooled, units=10)
def get_estimator_spec(features, labels, mode):
    """
    Get an estimator specification.

    Args:
        features: mnist image batch, float32 tensor of shape
            (batch_size, 28, 28, 1)
        labels: mnist label batch, int32 tensor of shape (batch_size,)
        mode: one of `tf.estimator.ModeKeys`, i.e. {"train", "infer", "eval"}

    Returns:
        tf.estimator.EstimatorSpec

    Raises:
        ValueError: if `mode` is not one of {"train", "infer", "eval"}.
    """
    # Validate first so an invalid mode fails before any graph ops are built.
    if mode not in {"train", "infer", "eval"}:
        raise ValueError('mode should be in {"train", "infer", "eval"}')

    logits = get_logits(features)
    preds = tf.argmax(logits, axis=-1)
    probs = tf.nn.softmax(logits, axis=-1)
    predictions = dict(preds=preds, probs=probs, image=features)

    # Prediction needs neither labels nor a loss, so return early.
    if mode == 'infer':
        return tf.estimator.EstimatorSpec(mode=mode, predictions=predictions)

    loss = tf.losses.sparse_softmax_cross_entropy(labels, logits)
    accuracy = tf.metrics.accuracy(labels, preds)

    # Only training needs an optimizer; building it for evaluation would add
    # unused optimizer ops and variables to the evaluation graph.
    train_op = None
    if mode == 'train':
        optimizer = tf.train.AdamOptimizer(learning_rate=1e-3)
        step = tf.train.get_or_create_global_step()
        train_op = optimizer.minimize(loss, global_step=step)

    return tf.estimator.EstimatorSpec(
        mode=mode, predictions=predictions,
        loss=loss, train_op=train_op, eval_metric_ops=dict(accuracy=accuracy))
# Directory handed to the Estimator as its `model_dir`.
model_dir = '/tmp/mnist_simple'


def get_estimator():
    """
    Build the estimator for this model.

    Returns:
        A `tf.estimator.Estimator` constructed with `get_estimator_spec` as
        its model function and the module-level `model_dir` as its model
        directory.
    """
    estimator = tf.estimator.Estimator(
        model_fn=get_estimator_spec, model_dir=model_dir)
    return estimator