Using TensorFlow's metrics-at-thresholds API raises a squeeze-dimension error

Date: 2019-02-14 22:25:26

Tags: tensorflow

I have been trying to add tf.metrics.false_positives_at_thresholds(), but it keeps raising the following error:

InvalidArgumentError: Tried to explicitly squeeze dimension 1 but dimension was not 1: 10
     [[Node: my_metric_7/remove_squeezable_dimensions/cond/Squeeze = Squeeze[T=DT_FLOAT, squeeze_dims=[-1], _device="/job:localhost/replica:0/task:0/device:CPU:0"](my_metric_7/remove_squeezable_dimensions/cond/Switch_1:1, ^my_metric_7/assert_greater_equal/Assert/AssertGuard/Merge, ^my_metric_7/assert_less_equal/Assert/AssertGuard/Merge)]]
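The node name points at the shape-alignment step: before counting, the metric compares the ranks of the labels and the predictions, and when the predictions have exactly one more dimension it squeezes their last axis. The "10" in the message would then be the size of that axis, most likely the number of classes in the softmax output, which in turn suggests that batch_y is a 1-D vector of class indices (an inference from the error, not something the question states). The sketch below is a simplified NumPy model of that rank check, not TensorFlow's actual code; the function name remove_squeezable_dimensions_model and the sample shapes are hypothetical.

```python
import numpy as np


def remove_squeezable_dimensions_model(labels, predictions):
    """Simplified model of the rank check: if predictions have exactly one
    more dimension than labels, drop their last axis; if labels have one
    more dimension than predictions, drop the last axis of labels."""
    rank_diff = predictions.ndim - labels.ndim
    if rank_diff == 1:
        # np.squeeze raises ValueError when that axis is not of size 1,
        # analogous to TensorFlow's "dimension was not 1" error.
        predictions = np.squeeze(predictions, axis=-1)
    elif rank_diff == -1:
        labels = np.squeeze(labels, axis=-1)
    return labels, predictions


num_classes = 10
class_ids = np.array([3, 0, 7, 1])          # shape (4,): integer class labels
probs = np.full((4, num_classes), 0.1)      # shape (4, 10): softmax-like output

try:
    remove_squeezable_dimensions_model(class_ids, probs)
except ValueError as err:
    print("squeeze failed:", err)

# One-hot labels have the same rank as the predictions, so nothing is squeezed.
one_hot = np.eye(num_classes)[class_ids]    # shape (4, 10)
labels_out, preds_out = remove_squeezable_dimensions_model(one_hot, probs)
print(labels_out.shape, preds_out.shape)    # (4, 10) (4, 10)
```

In this model, rank-1 labels against rank-2 predictions reproduce the failure, while labels shaped exactly like the predictions pass through untouched.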

The code below is the part of my program that calls this API. I am using TensorFlow 1.12 with Python 3.x:

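import tensorflow as tf  # TensorFlow 1.12

# logits, x, X_data, y_data and thresholds are defined elsewhere in the program.
y2 = tf.placeholder(tf.float32, (None))                    # labels; shape left unspecified
predictionpercentage = tf.placeholder(tf.float32, (None))  # predicted probabilities; shape left unspecified

# Returns (value tensor, update op).
tf_false_threshold, tf_metric_false_threshold = tf.metrics.false_positives_at_thresholds(y2, predictionpercentage, thresholds, name="my_metric")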

running_vars = tf.get_collection(tf.GraphKeys.LOCAL_VARIABLES, scope="my_metric")
running_vars_initializer = tf.variables_initializer(var_list=running_vars)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())  # metric counters are local variables
    sess.run(running_vars_initializer)          # resets only the "my_metric" counters
    batch_x, batch_y = X_data, y_data
    prediction = sess.run(logits, feed_dict={x: batch_x})
    predictionpro = tf.nn.softmax(prediction)   # note: adds new ops to the graph on every call
    predicionprobability = sess.run(predictionpro)
    feed_dict = {y2: batch_y, predictionpercentage: predicionprobability}
    sess.run(tf_metric_false_threshold, feed_dict=feed_dict)  # update op: accumulates counts
    falsepositives_threshold = sess.run(tf_false_threshold)   # reads the accumulated counts
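false_positives_at_thresholds expects labels and predictions of the same shape, and it casts labels to bool, so integer class indices such as 3 or 7 would all count as "true" even if the ranks matched. For multi-class output, one option (an assumption about the intent, since the question does not say which class is the positive one) is to one-hot encode batch_y to (batch, num_classes), for example with tf.one_hot, and feed that into y2; every (example, class) cell then becomes an independent binary prediction. The NumPy sketch below is a reference model of what the metric counts in that setup: for each threshold t, the cells whose prediction is strictly greater than t while the label is false. It is meant for checking results by hand, not as TensorFlow's implementation; false_positives_at_thresholds_np and the sample data are hypothetical.

```python
import numpy as np


def false_positives_at_thresholds_np(labels, predictions, thresholds):
    """For each threshold t, count elements with prediction > t whose label
    is false. labels and predictions must have the same shape."""
    labels = np.asarray(labels).astype(bool)
    predictions = np.asarray(predictions, dtype=np.float64)
    if labels.shape != predictions.shape:
        raise ValueError(f"shape mismatch: {labels.shape} vs {predictions.shape}")
    # Flatten so every (example, class) cell is one binary prediction.
    labels = labels.reshape(-1)
    predictions = predictions.reshape(-1)
    thresholds = np.asarray(thresholds, dtype=np.float64)
    # Boolean grid of shape (num_thresholds, num_cells).
    is_positive = predictions[np.newaxis, :] > thresholds[:, np.newaxis]
    is_negative_label = ~labels[np.newaxis, :]
    return np.sum(is_positive & is_negative_label, axis=1)


num_classes = 3
class_ids = np.array([0, 2])                # integer class labels, shape (2,)
one_hot = np.eye(num_classes)[class_ids]    # shape (2, 3), same as predictions
probs = np.array([[0.7, 0.2, 0.1],
                  [0.3, 0.3, 0.4]])
print(false_positives_at_thresholds_np(one_hot, probs, [0.0, 0.25, 0.5]))  # [4 2 0]
```

Here the four negative cells hold 0.2, 0.1, 0.3 and 0.3, so all four exceed 0.0, two exceed 0.25, and none exceed 0.5. Passing the raw class_ids instead of one_hot raises the shape-mismatch error, which is the same mismatch the TensorFlow graph trips over.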

0 answers
