tf.metrics.accuracy无法正常工作

时间:2017-11-27 23:55:10

标签: python-3.x machine-learning tensorflow linear-regression metrics

我有线性回归模型似乎工作正常,但我想显示模型的准确性。

首先,我初始化变量和占位符......

X_train, X_test, Y_train, Y_test = train_test_split(
    X_data, 
    Y_data, 
    test_size=0.2
)

n_rows = X_train.shape[0]

X = tf.placeholder(tf.float32, [None, 89])
Y = tf.placeholder(tf.float32, [None, 1])

W_shape = tf.TensorShape([89, 1])
b_shape = tf.TensorShape([1])

W = tf.Variable(tf.random_normal(W_shape))
b = tf.Variable(tf.random_normal(b_shape))

pred = tf.add(tf.matmul(X, W), b)

cost = tf.reduce_sum(tf.pow(pred-Y, 2)/(2*n_rows-1))

optimizer = tf.train.GradientDescentOptimizer(FLAGS.learning_rate).minimize(cost)

X_train的形状为(6702, 89)Y_train的形状为(6702, 1)。接下来我运行会话,我显示每个时期的成本以及总MSE ......

init = tf.global_variables_initializer()

with tf.Session() as sess:

    sess.run(init)

    for epoch in range(FLAGS.training_epochs):

        avg_cost = 0

        for (x, y) in zip(X_train, Y_train):

            x = np.reshape(x, (1, 89))
            y = np.reshape(y, (1,1))
            sess.run(optimizer, feed_dict={X:x, Y:y})

        # display logs per epoch step
        if (epoch + 1) % FLAGS.display_step == 0:

            c = sess.run(
                cost, 
                feed_dict={X:X_train, Y:Y_train}
            )

            y_pred = sess.run(pred, feed_dict={X:X_test})
            test_error = r2_score(Y_test, y_pred)
            print(test_error)

            print("Epoch:", '%04d' % (epoch + 1), "cost=", "{:.9f}".format(c))

    print("Optimization Finished!")

    pred_y = sess.run(pred, feed_dict={X:X_test})
    mse = tf.reduce_mean(tf.square(pred_y - Y_test))

    print("MSE: %4f" % sess.run(mse))

这一切似乎都正常。但是,现在我想看看我的模型的准确性,所以我想实现tf.metrics.accuracy。文档说它有2个参数,labelspredictions。我接下来添加了以下内容......

accuracy, accuracy_op = tf.metrics.accuracy(labels=Y_test, predictions=pred)

init_local = tf.local_variables_initializer()

sess.run(init_local)

print(sess.run(accuracy))

显然我需要初始化局部变量,但是我认为我做错了,因为打印出的准确度结果是0.0

我到处寻找一个有效的例子,但我无法让它为我的模型工作,实现它的正确方法是什么?

0 个答案:

没有答案