Tensorflow:使用公制tf.metrics.recall_at_thresholds评估具有不平衡数据的二进制分类器

时间:2017-08-16 19:24:14

标签: python tensorflow

我正在尝试在Tensorflow中评估受不平衡数据训练的二元分类器,并且在使用 tf.metrics.recall_at_thresholds

时遇到问题
one_hot_targets=tf.reshape(tf.one_hot(tf.cast(Y,tf.int32),2),[-1,2]) #Y are labels
weights=tf.reshape(tf.transpose(tf.matmul(one_hot_targets,tf.transpose([[ratio,1.0-ratio]]))),[-1,1])
recalls=tf.metrics.recall_at_thresholds(labels=Y,predictions=func_8,thresholds=[0.2,0.3,0.4,0.6,0.7],weights=weights)

但这会在Session.run()中给出以下错误:

Attempting to use uninitialized value recall_at_thresholds/true_positives
     [[Node: recall_at_thresholds/true_positives/read = Identity[T=DT_FLOAT, _class=["loc:@recall_at_thresholds/true_positives"], _device="/job:localhost/replica:0/task:0/cpu:0"](recall_at_thresholds/true_positives)]]

请注意,会话中使用 feed dict 传递比率 Y

1 个答案:

答案 0 :(得分:1)

似乎你在dint初始化了相关的局部和全局变量

sess.run([tf.global_variables_initializer(), tf.local_variables_initializer()])