TensorFlow classification doesn't match the actual output

Date: 2017-05-10 16:44:17

Tags: machine-learning tensorflow neural-network linear-regression

I wrote code in TensorFlow for linear classification. I generate fake data according to a rule: "if the difference is greater than x (some constant), the output should be [1,0]; otherwise the output should be [0,1]" (a sketch of such a generator appears after the data_supplier code below). Here is my code:

import argparse
import sys

import tensorflow as tf

import data_supplier


def weight_variable(shape):
    initial = tf.truncated_normal(shape, stddev=0.1)
    return tf.Variable(initial)

def bias_variable(shape):
    initial = tf.constant(0.1, shape=shape)
    return tf.Variable(initial)

def main(_):
    # Create the model
    x = tf.placeholder(tf.float32, [None, 2])
    W1 = weight_variable([2, 2])
    b1 = bias_variable([2])
    y2 = tf.nn.softmax(tf.matmul(x, W1) + b1)

    # Define loss and optimizer
    y_ = tf.placeholder(tf.float32, [None, 2])
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y2))
    train_step = tf.train.GradientDescentOptimizer(0.001).minimize(cross_entropy)

    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

    # Train
    for _ in range(10000):
        batch_xs, batch_ys = data_supplier.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

    # Test trained model
    test_batch_x, test_batch_y = data_supplier.test_data()
    correct_prediction = tf.equal(tf.argmax(y2, 1), tf.argmax(y_, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(sess.run(accuracy, feed_dict={x: test_batch_x, y_: test_batch_y}))
    print(x.eval(feed_dict={x: test_batch_x, y_: test_batch_y}))
    print(y2.eval(feed_dict={x: test_batch_x, y_: test_batch_y}))
    print(y_.eval(feed_dict={x: test_batch_x, y_: test_batch_y}))
    print(W1.eval())
    print(b1.eval())
    print(cross_entropy.eval(feed_dict={x: test_batch_x, y_: test_batch_y}))

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', type=str, default='/tmp/tensorflow/mnist/input_data',
                        help='Directory for storing input data')
    FLAGS, unparsed = parser.parse_known_args()
    tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

Here is the data_supplier code:

import numpy as np
import pandas as pd

TOTAL_DATA_SIZE = 50000
TRAIN_DATA_SIZE = 40000
VALIDATION_DATA_SIZE = 5000
TEST_DATA_SIZE = 5000
COLUMNS = ["a", "b", "output", "outputbar"]
FEATURES = ["a", "b"]
LABELS = ["output", "outputbar"]
training_set = pd.read_csv("train.csv", skipinitialspace=True, skiprows=1, names=COLUMNS)
training_set_features = training_set.as_matrix(columns=FEATURES)
training_set_labels = training_set.as_matrix(columns=LABELS)
test_set = pd.read_csv("test.csv", skipinitialspace=True, skiprows=1, names=COLUMNS)
test_set_features = test_set.as_matrix(columns=FEATURES)
test_set_labels = test_set.as_matrix(columns=LABELS)


def next_batch(BATCH_SIZE):
    # Returns a contiguous slice of the training data starting at a random offset.
    k = np.random.randint(0, TRAIN_DATA_SIZE - BATCH_SIZE)
    return training_set_features[k:k + BATCH_SIZE], training_set_labels[k:k + BATCH_SIZE]


def test_data():
    return test_set_features, test_set_labels
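
For reference, here is a minimal sketch of how train.csv and test.csv following the rule above could be generated. The threshold X_THRESHOLD = 0.25, the reading of "difference" as |a - b|, and the uniform inputs are all assumptions, since the question only states the rule informally:

import numpy as np
import pandas as pd

TOTAL_DATA_SIZE = 50000
TRAIN_DATA_SIZE = 40000
X_THRESHOLD = 0.25  # hypothetical value for the constant "x" in the rule

a = np.random.rand(TOTAL_DATA_SIZE)
b = np.random.rand(TOTAL_DATA_SIZE)
diff = np.abs(a - b)  # assuming "difference" means |a - b|
output = (diff > X_THRESHOLD).astype(float)  # 1.0 where the rule says [1,0]
frame = pd.DataFrame({"a": a, "b": b, "output": output, "outputbar": 1.0 - output},
                     columns=["a", "b", "output", "outputbar"])
frame[:TRAIN_DATA_SIZE].to_csv("train.csv", index=False)
frame[TRAIN_DATA_SIZE + 5000:].to_csv("test.csv", index=False)  # last 5000 rows; the middle 5000 would be validation

Printing output.mean() on such data immediately shows how balanced the two classes are, which is relevant to the imbalance point in the answer below.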

Here is the output:

accuracy: 0.6852
Input: [[ 0.51166666  0.79333335]
 [ 0.85833335  0.21833333]
 [ 0.80333334  0.48333332]
 ..., 
 [ 0.28333333  0.96499997]
 [ 0.97666669  0.84833336]
 [ 0.57666665  0.21333334]]
Predictions: [[ 0.80804855  0.19195142]
 [ 0.78380686  0.21619321]
 [ 0.80210352  0.19789645]
 ..., 
 [ 0.80708122  0.19291875]
 [ 0.83949649  0.16050354]
 [ 0.76328743  0.23671262]]
Actual output: [[ 0.  1.]
 [ 1.  0.]
 [ 1.  0.]
 ..., 
 [ 1.  0.]
 [ 1.  0.]
 [ 1.  0.]]
Weights: [[ 0.3034386  -0.10369452]
 [ 0.29422989 -0.21103808]]
Bias: [ 0.5141086 -0.3141087]
Cross Entropy: 0.624272

Right now the accuracy is meaningless, because every prediction is classified as [1,0]. What mistake am I making?

1 Answer:

Answer 0 (score: 0):

You should use softmax_cross_entropy_with_logits() the way it is designed to be used, i.e. with logits: https://www.tensorflow.org/api_docs/python/tf/nn/softmax_cross_entropy_with_logits. The warning in the docs means the function expects unscaled logits, i.e. you should not apply a softmax to y2; define it as y2 = tf.matmul(x, W1) + b1. For the test predictions you then apply the softmax yourself: y_out_test = tf.nn.softmax(y2), or something similar. Maybe this already solves your problem.
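
In code, the suggested change against the question's model is just moving the softmax (a sketch; y_out_test is the name used above for the test-time probabilities):

# y2 stays as raw, unscaled logits -- no softmax here, because
# softmax_cross_entropy_with_logits() applies the softmax internally.
y2 = tf.matmul(x, W1) + b1

cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y2))
train_step = tf.train.GradientDescentOptimizer(0.001).minimize(cross_entropy)

# Apply the softmax only where you actually want probabilities:
y_out_test = tf.nn.softmax(y2)

The accuracy computation does not need to change: since softmax is monotonic, tf.argmax(y2, 1) on the raw logits picks the same class as tf.argmax on the softmax output.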

If it doesn't, predicting only a single class usually hints at an imbalance in the dataset, i.e. one class occurring much more frequently than the other. You should check whether that is the case for you. If so, you may find some advice on how to deal with it, e.g. here: http://machinelearningmastery.com/tactics-to-combat-imbalanced-classes-in-your-machine-learning-dataset/. I haven't checked that site in detail, so I can't say whether it is particularly helpful, but if your dataset is heavily imbalanced you may want to consult a few pages on the topic.
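
A quick way to check for such an imbalance (a sketch, assuming the train.csv layout from the question):

import pandas as pd

COLUMNS = ["a", "b", "output", "outputbar"]
training_set = pd.read_csv("train.csv", skipinitialspace=True, skiprows=1, names=COLUMNS)

# Fraction of rows in each class; values far from 0.5 indicate imbalance.
print(training_set["output"].value_counts(normalize=True))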