我正在使用一个玩具示例来检查tensorflow.metrics.sparse_precision_at_k
的工作原理
来自文档:
标签:
int64
Tensor
或SparseTensor
具有形状 [D1,... DN,num_labels]或[D1,... DN],后者暗示 num_labels = 1。 N> = 1且num_labels是目标类的数量 相关的预测。通常,N = 1且labels
具有形状 [batch_size,num_labels]。 [D1,... DN]必须与predictions
匹配。值 应该在[0,num_classes]范围内,其中num_classes是最后一个 维度predictions
。超出此范围的值将被忽略。预测:浮动
Tensor
形状[D1,... DN,num_classes]在哪里 N> = 1.通常,N = 1并且预测具有形状[批量大小,num_classes]。 最终维度包含每个类的logit值。 [D1,...... DN] 必须与labels
匹配。k:整数,k代表@k metric。
所以我写了一个相应的例子:
import tensorflow as tf
import numpy as np
pred = np.asarray([[.8,.1,.1,.1],[.2,.9,.9,.9]]).T
print(pred.shape)
segm = [0,1,1,1]
segm = np.asarray(segm, np.float32)
print(segm.shape)
segm_tf = tf.Variable(segm, np.int64)
pred_tf = tf.Variable(pred, np.float32)
print("segm_tf", segm_tf.shape)
print("pred_tf", pred_tf.shape)
prec,_ = tf.metrics.sparse_precision_at_k(segm_tf, pred_tf, 1, class_id=1)
sess = tf.InteractiveSession()
tf.variables_initializer([prec, segm_tf, pred_tf])
然而,我收到一个错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-7-c6243802dedc> in <module>()
25 print("pred_tf", pred_tf.shape)
26
---> 27 prec,_ = tf.metrics.sparse_precision_at_k(segm_tf, pred_tf, 1, class_id=1)
28 sess = tf.InteractiveSession()
29 tf.variables_initializer([prec, segm_tf, pred_tf])
/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in sparse_precision_at_k(labels, predictions, k, class_id, weights, metrics_collections, updates_collections, name)
2828 metrics_collections=metrics_collections,
2829 updates_collections=updates_collections,
-> 2830 name=scope)
2831
2832
/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in _sparse_precision_at_top_k(labels, predictions_idx, k, class_id, weights, metrics_collections, updates_collections, name)
2726 tp, tp_update = _streaming_sparse_true_positive_at_k(
2727 predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
-> 2728 weights=weights)
2729 fp, fp_update = _streaming_sparse_false_positive_at_k(
2730 predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in _streaming_sparse_true_positive_at_k(labels, predictions_idx, k, class_id, weights, name)
1743 tp = _sparse_true_positive_at_k(
1744 predictions_idx=predictions_idx, labels=labels, class_id=class_id,
-> 1745 weights=weights)
1746 batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
1747
/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in _sparse_true_positive_at_k(labels, predictions_idx, class_id, weights, name)
1689 name, 'true_positives', (predictions_idx, labels, weights)):
1690 labels, predictions_idx = _maybe_select_class_id(
-> 1691 labels, predictions_idx, class_id)
1692 tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
1693 tp = math_ops.to_double(tp)
/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in _maybe_select_class_id(labels, predictions_idx, selected_id)
1651 if selected_id is None:
1652 return labels, predictions_idx
-> 1653 return (_select_class_id(labels, selected_id),
1654 _select_class_id(predictions_idx, selected_id))
1655
/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in _select_class_id(ids, selected_id)
1627 filled_selected_id = array_ops.fill(
1628 filled_selected_id_shape, math_ops.to_int64(selected_id))
-> 1629 result = sets.set_intersection(filled_selected_id, ids)
1630 return sparse_tensor.SparseTensor(
1631 indices=result.indices, values=result.values, dense_shape=ids_shape)
/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/sets_impl.py in set_intersection(a, b, validate_indices)
191 intersections.
192 """
--> 193 a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b)
194 return _set_operation(a, b, "intersection", validate_indices)
195
/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/sets_impl.py in _convert_to_tensors_or_sparse_tensors(a, b)
82 b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
83 if b.dtype.base_dtype != a.dtype.base_dtype:
---> 84 raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
85 if (isinstance(a, sparse_tensor.SparseTensor) and
86 not isinstance(b, sparse_tensor.SparseTensor)):
TypeError: Types don't match, <dtype: 'int64'> vs <dtype: 'float32'>.
答案 0 :(得分:0)
以下是使用此指标的简单示例。
sess = tf.Session()
predictions = tf.constant([[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]],
dtype=tf.float32)
labels = tf.constant([3, 2], tf.int64)
precision_op, update_op = tf.metrics.sparse_precision_at_k(
labels=labels,
predictions=predictions,
k=1,
class_id=3)
sess.run(tf.local_variables_initializer())
print(sess.run(update_op))
此示例打印0.5,因为我们的预测为所有(两个)示例预测了第3级,并且只有其中一个是正确的。
两个返回的操作(precision_op
和update_op
)可能令人困惑。请阅读本指南 - https://www.tensorflow.org/api_guides/python/contrib.metrics。它讨论了“流式”指标,但同样的逻辑适用于所有指标。基本上,update_op
实际上使用您给出的示例/标签更新变量,precision_op
是幂等的 - 它只返回度量的当前值。如果您从未致电update_op
,则该指标的当前值未定义,可能为nan
。
关于您的代码,形状不正确。在最简单的情况下,标签应该为批处理中的每个示例提供正确的标签。在您的情况下,只有两个示例,因此应该只有两个标签。此外,您不需要自己创建变量 - sparse_precision_at_k
为您完成。