Tensorflow argmax给出错误的指示

时间:2017-05-06 01:15:53

标签: tensorflow

我正在使用tensorflow来训练和使用一个小型神经网络(2d分类有两个类),但我有一个非常奇怪的问题而且看不出我做错了什么: 当我仅为测试批次绘制the predictions vs the true labels时,准确度会变为1.,我显然有一些错误分类的样本。在我看来,tf.argmax是错误评估为1的准确性的原因,但显然这不是真正的原因。 无论如何,我通过计算最后一层输出的准确性得出了这个结论:

with tf.name_scope('accuracy'):
    plabel = tf.argmax(y, 1) # vector of predicted label, elem {0,1}^batch_size
    tlabel = tf.argmax(y_, 1) # similar vector of true labels
    correct_predictions = tf.cast(tf.equal(plabel, tlabel), tf.float32) 
...
with tf.Session() as sess:
    batchx, batchy = generate_batch()
    predictions, acc = sess.run([y, accuracy], feed_dict={x: batchx, y_: batchy})
    mistakes = 0.
    for j in range(batch_size):
        if (predictions[j, 0] - predictions[j, 1])*(batchy[j, 0] - batchy[j, 1]) < 0:
            print("mistake: ", predictions[j], batchy[j])
            mistakes += 1./batch_size
    print("Acc = {} / {} = 1-m/b".format(acc, 1. - mistakes))

x和y_是输入张量,y是最后一层,模型已经训练过。

它给了我以下输出:

Acc = 1.0 / 0.86 = 1-m/b

这些值应该相同。

该图还表明真实精度不是1.,或者评估的精度张量不属于与预测(y)相同的运行。

我发现没有任何迹象表明tf.argmax真的是问题所在,我非常绝望。所以,提前感谢任何帮助

2 个答案:

答案 0 :(得分:0)

您的准确度函数似乎是正确的,因为您没有发布整个代码,我建议您计算会话内的准确性,因此每次您都可以打印预测内容和真实标签并跟踪执行情况。

获取predictions = sess.run(y, feed_dict={x: batchx, y_: batchy})之类的预测,然后按print predictions.eval()打印预测,然后再次使用plabel = tf.argmax(predictions, 1)print plabel.eval()print batchyprint (tf.argmax(batchy)).eval()

现在您应该知道每个print语句出了什么问题。

答案 1 :(得分:0)

准确性:

def evaluate(X_data, y_data):
    num_examples = len(X_data)
    total_accuracy = 0
    sess = tf.get_default_session()
    for offset in range(0, num_examples, BATCH_SIZE):
        batch_x, batch_y = X_data[offset:offset+BATCH_SIZE], y_data[offset:offset+BATCH_SIZE]
        accuracy = sess.run(accuracy_operation, feed_dict={x: batch_x, y: batch_y})
        total_accuracy += (accuracy * len(batch_x))
    return total_accuracy / num_examples

获得预测:

prediction = sess.run(tf.argmax(logits, 1), feed_dict={x: train_data})