使用tensorflow.metrics中的sparse_precision_at_k键入不匹配

时间:2017-10-24 21:34:34

标签: tensorflow

我正在使用一个玩具示例来检查tensorflow.metrics.sparse_precision_at_k的工作原理

来自文档:

  

标签:int64 TensorSparseTensor具有形状       [D1,... DN,num_labels]或[D1,... DN],后者暗示       num_labels = 1。 N> = 1且num_labels是目标类的数量       相关的预测。通常,N = 1且labels具有形状       [batch_size,num_labels]。 [D1,... DN]必须与predictions匹配。值       应该在[0,num_classes]范围内,其中num_classes是最后一个       维度predictions。超出此范围的值将被忽略。

     

预测:浮动Tensor形状[D1,... DN,num_classes]在哪里       N> = 1.通常,N = 1并且预测具有形状[批量大小,num_classes]。       最终维度包含每个类的logit值。 [D1,...... DN]       必须与labels匹配。

     

k:整数,k代表@k metric。

所以我写了一个相应的例子:

import tensorflow as tf
import numpy as np

pred = np.asarray([[.8,.1,.1,.1],[.2,.9,.9,.9]]).T
print(pred.shape)

segm = [0,1,1,1]
segm = np.asarray(segm, np.float32)
print(segm.shape)

segm_tf = tf.Variable(segm, np.int64)
pred_tf = tf.Variable(pred, np.float32)

print("segm_tf", segm_tf.shape)
print("pred_tf", pred_tf.shape)

prec,_ = tf.metrics.sparse_precision_at_k(segm_tf, pred_tf, 1, class_id=1)
sess = tf.InteractiveSession()
tf.variables_initializer([prec, segm_tf, pred_tf])

然而,我收到一个错误:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-7-c6243802dedc> in <module>()
     25 print("pred_tf", pred_tf.shape)
     26 
---> 27 prec,_ = tf.metrics.sparse_precision_at_k(segm_tf, pred_tf, 1, class_id=1)
     28 sess = tf.InteractiveSession()
     29 tf.variables_initializer([prec, segm_tf, pred_tf])

/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in sparse_precision_at_k(labels, predictions, k, class_id, weights, metrics_collections, updates_collections, name)
   2828         metrics_collections=metrics_collections,
   2829         updates_collections=updates_collections,
-> 2830         name=scope)
   2831 
   2832 

/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in _sparse_precision_at_top_k(labels, predictions_idx, k, class_id, weights, metrics_collections, updates_collections, name)
   2726     tp, tp_update = _streaming_sparse_true_positive_at_k(
   2727         predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,
-> 2728         weights=weights)
   2729     fp, fp_update = _streaming_sparse_false_positive_at_k(
   2730         predictions_idx=top_k_idx, labels=labels, k=k, class_id=class_id,

/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in _streaming_sparse_true_positive_at_k(labels, predictions_idx, k, class_id, weights, name)
   1743     tp = _sparse_true_positive_at_k(
   1744         predictions_idx=predictions_idx, labels=labels, class_id=class_id,
-> 1745         weights=weights)
   1746     batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
   1747 

/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in _sparse_true_positive_at_k(labels, predictions_idx, class_id, weights, name)
   1689       name, 'true_positives', (predictions_idx, labels, weights)):
   1690     labels, predictions_idx = _maybe_select_class_id(
-> 1691         labels, predictions_idx, class_id)
   1692     tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
   1693     tp = math_ops.to_double(tp)

/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in _maybe_select_class_id(labels, predictions_idx, selected_id)
   1651   if selected_id is None:
   1652     return labels, predictions_idx
-> 1653   return (_select_class_id(labels, selected_id),
   1654           _select_class_id(predictions_idx, selected_id))
   1655 

/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/metrics_impl.py in _select_class_id(ids, selected_id)
   1627   filled_selected_id = array_ops.fill(
   1628       filled_selected_id_shape, math_ops.to_int64(selected_id))
-> 1629   result = sets.set_intersection(filled_selected_id, ids)
   1630   return sparse_tensor.SparseTensor(
   1631       indices=result.indices, values=result.values, dense_shape=ids_shape)

/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/sets_impl.py in set_intersection(a, b, validate_indices)
    191     intersections.
    192   """
--> 193   a, b, _ = _convert_to_tensors_or_sparse_tensors(a, b)
    194   return _set_operation(a, b, "intersection", validate_indices)
    195 

/home/ubuntu/anaconda3/lib/python3.6/site-packages/tensorflow/python/ops/sets_impl.py in _convert_to_tensors_or_sparse_tensors(a, b)
     82   b = sparse_tensor.convert_to_tensor_or_sparse_tensor(b, name="b")
     83   if b.dtype.base_dtype != a.dtype.base_dtype:
---> 84     raise TypeError("Types don't match, %s vs %s." % (a.dtype, b.dtype))
     85   if (isinstance(a, sparse_tensor.SparseTensor) and
     86       not isinstance(b, sparse_tensor.SparseTensor)):

TypeError: Types don't match, <dtype: 'int64'> vs <dtype: 'float32'>.

1 个答案:

答案 0 :(得分:0)

以下是使用此指标的简单示例。

sess = tf.Session()
predictions = tf.constant([[0.1, 0.3, 0.2, 0.4], [0.1, 0.2, 0.3, 0.4]],
                          dtype=tf.float32)
labels = tf.constant([3, 2], tf.int64)
precision_op, update_op = tf.metrics.sparse_precision_at_k(
       labels=labels,                                          
       predictions=predictions,
       k=1,
       class_id=3)
sess.run(tf.local_variables_initializer())

print(sess.run(update_op))

此示例打印0.5,因为我们的预测为所有(两个)示例预测了第3级,并且只有其中一个是正确的。

两个返回的操作(precision_opupdate_op)可能令人困惑。请阅读本指南 - https://www.tensorflow.org/api_guides/python/contrib.metrics。它讨论了“流式”指标,但同样的逻辑适用于所有指标。基本上,update_op实际上使用您给出的示例/标签更新变量,precision_op是幂等的 - 它只返回度量的当前值。如果您从未致电update_op,则该指标的当前值未定义,可能为nan

关于您的代码,形状不正确。在最简单的情况下,标签应该为批处理中的每个示例提供正确的标签。在您的情况下,只有两个示例,因此应该只有两个标签。此外,您不需要自己创建变量 - sparse_precision_at_k为您完成。