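I am training a CNN for image classification. Because my dataset is small, I am using transfer learning. Basically, I use the pre-trained network that Google demonstrates in its retraining example (https://www.tensorflow.org/tutorials/image_retraining).
The model works well and reaches very good accuracy. However, my dataset is highly imbalanced, so accuracy is not the best metric for judging the model's performance.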
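To make the accuracy problem concrete, here is a minimal pure-Python sketch using a hypothetical 95/5 class split. A "model" that always predicts the majority class reaches 95% accuracy while never finding a single positive example. The data and the helper names `accuracy` and `recall` are illustrative only.

```python
# Hypothetical imbalanced test set: 95 negatives (0) and 5 positives (1).
labels = [0] * 95 + [1] * 5

# A degenerate "model" that always predicts the majority class.
predictions = [0] * len(labels)


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the labels."""
    correct = sum(1 for t, p in zip(y_true, y_pred) if t == p)
    return correct / len(y_true)


def recall(y_true, y_pred, positive=1):
    """Fraction of actual positives that were predicted as positive."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p == positive)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == positive and p != positive)
    return tp / (tp + fn) if (tp + fn) else 0.0


print(f"accuracy = {accuracy(labels, predictions):.2f}")  # accuracy = 0.95
print(f"recall   = {recall(labels, predictions):.2f}")    # recall   = 0.00
```

High accuracy here only reflects the class ratio. A metric that looks at the minority class (recall) or at the ranking of scores (AUC) exposes the problem.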
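While looking into solutions, I found suggestions to either change the sampling method or change the performance metric. I chose the latter.
TensorFlow provides useful metrics, including AUC, precision, recall, and so on.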
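As a rough illustration of what a ROC AUC metric measures, the sketch below computes it in pure Python by sweeping evenly spaced thresholds over [0, 1] and integrating TPR against FPR with the trapezoidal rule. It is loosely modeled on the threshold-based confusion-matrix approach that `tf.metrics.auc` takes (the traceback further down shows it calling `_confusion_matrix_at_thresholds`), and the default of 200 thresholds mirrors that function's `num_thresholds` default. It is not TensorFlow's code, and `roc_auc` and the sample data are hypothetical. The usage lines also compare real-valued scores with hard 0/1 labels (roughly what an argmax over two classes keeps). With hard labels the curve has only one interior point.

```python
def roc_auc(labels, scores, num_thresholds=200):
    """Approximate ROC AUC by sweeping evenly spaced thresholds over [0, 1].

    labels: 0/1 ground-truth values.
    scores: predicted probabilities (or 0/1 predictions) for the positive class.
    """
    labels = list(labels)
    scores = list(scores)
    pos = sum(labels)
    neg = len(labels) - pos
    if pos == 0 or neg == 0:
        raise ValueError("need at least one positive and one negative example")

    # Thresholds from just below 0 to just above 1, so the curve starts at
    # (FPR, TPR) = (1, 1) and ends at (0, 0).
    eps = 1e-7
    inner = [i / (num_thresholds - 1) for i in range(1, num_thresholds - 1)]
    thresholds = [-eps] + inner + [1.0 + eps]

    points = []
    for th in thresholds:
        # A prediction counts as positive when its score is strictly above th.
        tp = sum(1 for y, s in zip(labels, scores) if y == 1 and s > th)
        fp = sum(1 for y, s in zip(labels, scores) if y == 0 and s > th)
        points.append((fp / neg, tp / pos))  # (FPR, TPR)

    # FPR decreases as the threshold rises; sum the trapezoids under the curve.
    area = 0.0
    for (x0, y0), (x1, y1) in zip(points, points[1:]):
        area += (x0 - x1) * (y0 + y1) / 2.0
    return area


labels = [0, 0, 0, 1, 1]
scores = [0.1, 0.2, 0.6, 0.4, 0.9]             # hypothetical positive-class probabilities
hard = [1 if s >= 0.5 else 0 for s in scores]  # hard labels from thresholding at 0.5

print(f"AUC from scores:      {roc_auc(labels, scores):.3f}")  # 0.833
print(f"AUC from hard labels: {roc_auc(labels, hard):.3f}")    # 0.583
```

Because the threshold grid is fixed, the result is exact only when every pair of distinct scores is separated by at least one threshold. With very close scores it becomes an approximation.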
Here is the code of the retraining script: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/image_retraining/retrain.py
I added the following to the add_evaluation_step(result_tensor, ground_truth_tensor) function:
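with tf.name_scope('AUC'):
    with tf.name_scope('prediction'):
        # Index of the highest-scoring class for each example.
        prediction = tf.argmax(result_tensor, 1)
    with tf.name_scope('AUC'):
        # tf.metrics.auc returns a (value, update_op) pair; update_op must be
        # run to accumulate the counts that auc_value is computed from.
        auc_value, update_op = tf.metrics.auc(
            tf.argmax(ground_truth_tensor, 1), prediction, curve='ROC')
# evaluation_step is the accuracy tensor defined earlier in add_evaluation_step.
tf.summary.scalar('accuracy', evaluation_step)
tf.summary.scalar('AUC', auc_value)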
But I got this error:
Traceback (most recent call last):
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 1135, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 911, in main
    ground_truth_input: train_ground_truth})
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 767, in run
    run_metadata_ptr)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 965, in _run
    feed_dict_string, options, run_metadata)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 1015, in _do_run
    target_list, options, run_metadata)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/client/session.py", line 1035, in _do_call
    raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.FailedPreconditionError: Attempting to use uninitialized value AUC/AUC/auc/false_positives
	 [[Node: AUC/AUC/auc/false_positives/read = Identity[T=DT_FLOAT, _class=["loc:@AUC/AUC/auc/false_positives"], _device="/job:localhost/replica:0/task:0/cpu:0"]]]

Caused by op u'AUC/AUC/auc/false_positives/read', defined at:
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 1135, in <module>
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/platform/app.py", line 44, in run
    _sys.exit(main(_sys.argv[:1] + flags_passthrough))
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 874, in main
    final_tensor, ground_truth_input)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/examples/image_retraining/retrain.py", line 806, in add_evaluation_step
    auc_value, update_op = tf.metrics.auc(tf.argmax(ground_truth_tensor, 1), prediction, curve='ROC')
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 555, in auc
    labels, predictions, thresholds, weights)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 473, in _confusion_matrix_at_thresholds
    false_p = _create_local('false_positives', shape=[num_thresholds])
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/metrics_impl.py", line 177, in _create_local
    validate_shape=validate_shape)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/variables.py", line 226, in __init__
    expected_shape=expected_shape)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/variables.py", line 344, in _init_from_args
    self._snapshot = array_ops.identity(self._variable, name="read")
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/ops/gen_array_ops.py", line 1490, in identity
    result = _op_def_lib.apply_op("Identity", input=input, name=name)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/framework/op_def_library.py", line 768, in apply_op
    op_def=op_def)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/framework/ops.py", line 2402, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/home/user_2/tensorflow/bazel-bin/tensorflow/examples/image_retraining/retrain.runfiles/org_tensorflow/tensorflow/python/framework/ops.py", line 1264, in __init__
    self._traceback = _extract_stack()

FailedPreconditionError (see above for traceback): Attempting to use uninitialized value AUC/AUC/auc/false_positives
	 [[Node: AUC/AUC/auc/false_positives/read = Identity[T=DT_FLOAT, _class=["loc:@AUC/AUC/auc/false_positives"], _device="/job:localhost/replica:0/task:0/cpu:0"]]]
But I don't understand why this happens, because in main I have:
init = tf.global_variables_initializer()
sess.run(init)
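Answer (score: 18)
Try this, which also initializes the local variables that tf.metrics.auc creates for its running counts (tf.global_variables_initializer() does not cover them):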
init = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
sess.run(init)
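The traceback shows `tf.metrics.auc` creating its `false_positives` counter through `_create_local`, that is, as a local variable. The metric is "streaming": its counters are state, an update op adds each batch's counts to them, and the value is read from whatever has been accumulated. The pure-Python class below is only an analogy of that pattern, not TensorFlow code. `StreamingRecall` and its methods are hypothetical, and the `RuntimeError` merely imitates the `FailedPreconditionError` above.

```python
class StreamingRecall:
    """Toy analogue of a TF 1.x streaming metric (hypothetical, not TF code).

    The running counts stand in for the local variables that a tf.metrics
    function creates; they must be initialized before they are used.
    """

    def __init__(self):
        # "Uninitialized" state, like a local variable nobody initialized yet.
        self.true_positives = None
        self.false_negatives = None

    def initialize(self):
        """Plays the role of running tf.local_variables_initializer()."""
        self.true_positives = 0
        self.false_negatives = 0

    def _check_initialized(self):
        if self.true_positives is None:
            raise RuntimeError("Attempting to use uninitialized value true_positives")

    def update(self, labels, predictions):
        """Plays the role of running update_op on one batch."""
        self._check_initialized()
        for y, p in zip(labels, predictions):
            if y == 1 and p == 1:
                self.true_positives += 1
            elif y == 1:
                self.false_negatives += 1

    def value(self):
        """Plays the role of evaluating the metric tensor: reads the counts."""
        self._check_initialized()
        total = self.true_positives + self.false_negatives
        return self.true_positives / total if total else 0.0


metric = StreamingRecall()
try:
    metric.value()
except RuntimeError as err:
    print("error:", err)  # error: Attempting to use uninitialized value true_positives

metric.initialize()
metric.update([1, 0, 1], [1, 0, 0])  # batch 1: 1 true positive, 1 false negative
metric.update([1, 1, 0], [1, 1, 1])  # batch 2: 2 true positives
print(metric.value())  # 0.75
```

Calling `initialize()` again resets the counts. This plays the same role as re-running the local-variables initializer when a streaming metric should start over, for example before each evaluation pass.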