TensorFlow: Attempting to use uninitialized value AUC/AUC/auc/false_positives

Time: 2017-06-07 20:49:02

Tags: python tensorflow

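I am training a CNN for image classification. Because my dataset is small, I am using transfer learning. Specifically, I am using the pre-trained network that Google demonstrates in its retraining example (https://www.tensorflow.org/tutorials/image_retraining).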

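The model works well and reaches very good accuracy. However, my dataset is highly imbalanced, so accuracy is not the best metric for judging the model's performance.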
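To see why accuracy can mislead on an imbalanced dataset, here is a small pure-Python illustration with made-up labels (95 negatives and 5 positives, not the data from this question): a classifier that always predicts the majority class reaches 95% accuracy while finding none of the positives.

```python
# Hypothetical labels: 95 negatives (0) and 5 positives (1).
labels = [0] * 95 + [1] * 5

# A degenerate "classifier" that always predicts the majority class.
predictions = [0] * len(labels)


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


def recall(y_true, y_pred, positive=1):
    """Fraction of actual positives that were predicted as positive."""
    true_pos = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    actual_pos = sum(1 for t in y_true if t == positive)
    return true_pos / actual_pos if actual_pos else 0.0


print(accuracy(labels, predictions))  # 0.95
print(recall(labels, predictions))    # 0.0
```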

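While researching solutions, I found that people suggest either changing the sampling method or changing the performance metric. I chose the latter.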

TensorFlow provides useful metrics, including AUC, precision, and recall.
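A minimal pure-Python sketch of what ROC AUC measures, using hypothetical scores: it is the fraction of (positive, negative) pairs in which the positive example receives the higher score, with ties counting half. tf.metrics.auc approximates the same quantity from confusion counts at a fixed set of thresholds (200 by default), so its value can differ slightly from this exact pairwise version.

```python
def roc_auc(labels, scores):
    """ROC AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a correct pair.
    This is O(P * N), which is fine for a small illustration.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical example: 4 negatives, 2 positives.
labels = [0, 0, 0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8, 0.7, 0.9]
print(roc_auc(labels, scores))  # 0.875 (7 of the 8 pairs are ranked correctly)
```

Because the comparison works on rankings, hard 0/1 predictions (such as the output of tf.argmax) leave only a single operating point to compare, so class probabilities usually give a more informative AUC.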

Here is the code for retraining the model: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py

I added the following to the add_evaluation_step(result_tensor, ground_truth_tensor) function:

  with tf.name_scope('AUC'):
    with tf.name_scope('prediction'):
      prediction = tf.argmax(result_tensor, 1)
    with tf.name_scope('AUC'):
      # tf.metrics.auc returns a (value, update_op) pair; auc_value only
      # changes after update_op has been run.
      auc_value, update_op = tf.metrics.auc(
          tf.argmax(ground_truth_tensor, 1), prediction, curve='ROC')

  tf.summary.scalar('accuracy', evaluation_step)
  tf.summary.scalar('AUC', auc_value)

But I got this error:

    Traceback (most recent call last):
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 1135, in <module>
        tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run
        _sys.exit(main(_sys.argv[:1] + flags_passthrough))
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 911, in main
        ground_truth_input: train_ground_truth})
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 767, in run
        run_metadata_ptr)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 965, in _run
        feed_dict_string, options, run_metadata)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 1015, in _do_run
        target_list, options, run_metadata)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 1035, in _do_call
        raise type(e)(node_def, op, message)
    tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value AUC/AUC/auc/false_positives
         [[Node: AUC/AUC/auc/false_positives/read = Identity[T=DT_FLOAT, _class=["loc:@AUC/AUC/auc/false_positives"], _device="/job:localhost/replica:0/task:0/cpu:0"](AUC/AUC/auc/false_positives)]]

    Caused by op u'AUC/AUC/auc/false_positives/read', defined at:
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 1135, in <module>
        tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run
        _sys.exit(main(_sys.argv[:1] + flags_passthrough))
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 874, in main
        final_tensor, ground_truth_input)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 806, in add_evaluation_step
        auc_value, update_op = tf.metrics.auc(tf.argmax(ground_truth_tensor, 1), prediction, curve='ROC')
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 555, in auc
        labels, predictions, thresholds, weights)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 473, in _confusion_matrix_at_thresholds
        false_p = _create_local('false_positives', shape=[num_thresholds])
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 177, in _create_local
        validate_shape=validate_shape)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/variables.py", line 226, in __init__
        expected_shape=expected_shape)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/variables.py", line 344, in _init_from_args
        self._snapshot = array_ops.identity(self._variable, name="read")
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/gen_array_ops.py", line 1490, in identity
        result = _op_def_lib.apply_op("Identity", input=input, name=name)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/framework/op_def_library.py", line 768, in apply_op
        op_def=op_def)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/framework/ops.py", line 2402, in create_op
        original_op=self._default_original_op, op_def=op_def)
      File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/framework/ops.py", line 1264, in __init__
        self._traceback = _extract_stack()

    FailedPreconditionError (see above for traceback): Attempting to use uninitialized value AUC/AUC/auc/false_positives
         [[Node: AUC/AUC/auc/false_positives/read = Identity[T=DT_FLOAT, _class=["loc:@AUC/AUC/auc/false_positives"], _device="/job:localhost/replica:0/task:0/cpu:0"](AUC/AUC/auc/false_positives)]]

But I don't understand why this happens, because in main I have this:

init = tf.global_variables_initializer()
sess.run(init)

1 Answer

Answer 0 (score: 18)

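The counters that tf.metrics.auc keeps (true_positives, false_positives, and so on) are created as local variables, which tf.global_variables_initializer() does not initialize. Initialize the local variables as well: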

init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init)
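For intuition about why these counters live apart from the model's global variables, here is a pure-Python analogy of a streaming AUC metric. It is not TensorFlow code: the class StreamingAUC and its methods are hypothetical. update() plays the role of update_op (adding a batch to per-threshold confusion counts), result() plays the role of the returned AUC value, and reset() zeroes the counts, which is what running tf.local_variables_initializer() does for the real metric. Like tf.metrics.auc, it evaluates evenly spaced thresholds and sums trapezoids under the ROC curve.

```python
class StreamingAUC:
    """Hypothetical pure-Python analogue of a streaming ROC AUC metric."""

    def __init__(self, num_thresholds=200):
        eps = 1e-7
        # Evenly spaced thresholds, with end points just outside [0, 1].
        inner = [(i + 1) / (num_thresholds - 1) for i in range(num_thresholds - 2)]
        self.thresholds = [0.0 - eps] + inner + [1.0 + eps]
        self.reset()

    def reset(self):
        """Zero all counts (analogous to re-running the local-variable initializer)."""
        n = len(self.thresholds)
        self.tp = [0] * n
        self.fp = [0] * n
        self.tn = [0] * n
        self.fn = [0] * n

    def update(self, labels, scores):
        """Add one batch of (0/1 label, score in [0, 1]) pairs to the counts."""
        for y, s in zip(labels, scores):
            for i, t in enumerate(self.thresholds):
                predicted_pos = s > t
                if y == 1 and predicted_pos:
                    self.tp[i] += 1
                elif y == 1:
                    self.fn[i] += 1
                elif predicted_pos:
                    self.fp[i] += 1
                else:
                    self.tn[i] += 1

    def result(self):
        """Trapezoidal area under the ROC curve from the counts accumulated so far."""
        eps = 1e-7  # avoids division by zero when a count pair is empty
        tpr = [tp / (tp + fn + eps) for tp, fn in zip(self.tp, self.fn)]
        fpr = [fp / (fp + tn + eps) for fp, tn in zip(self.fp, self.tn)]
        # Thresholds increase, so fpr decreases; sum trapezoids between neighbours.
        return sum((fpr[i] - fpr[i + 1]) * (tpr[i] + tpr[i + 1]) / 2
                   for i in range(len(fpr) - 1))


metric = StreamingAUC()
print(metric.result())  # 0.0: nothing has been accumulated yet
metric.update([0, 0, 1], [0.1, 0.4, 0.7])   # first hypothetical batch
metric.update([0, 1, 0], [0.35, 0.9, 0.8])  # second hypothetical batch
print(round(metric.result(), 3))  # 0.875
```

Keeping the counts outside the model weights is what lets the metric accumulate across batches and be reset between evaluations without touching the trained parameters.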