I'm trying to build an MLP for the Covertype dataset (https://archive.ics.uci.edu/ml/datasets/covertype) without Keras, and my accuracy is very low compared to the ~0.85 I get with Keras. How can I improve the training accuracy?
The dataset has already been split and the labels one-hot encoded (a sketch of the preprocessing follows the shapes below):
X_train.shape, y_train.shape, X_test.shape, y_test.shape
((406703, 54), (406703, 7), (174309, 54), (174309, 7))
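For reference, the split and one-hot encoding can be reproduced along these lines (a minimal sketch using sklearn's fetch_covtype; the test_size and random_state here are illustrative, so the exact split sizes may differ slightly from mine):

import numpy as np
from sklearn.datasets import fetch_covtype
from sklearn.model_selection import train_test_split

# Covertype: 581012 samples, 54 features, integer labels 1..7
covtype = fetch_covtype()
features = covtype.data.astype(np.float32)
labels_onehot = np.eye(7, dtype=np.float32)[covtype.target - 1]  # one-hot encode the 7 classes

# ~70/30 train/test split
X_train, X_test, y_train, y_test = train_test_split(
    features, labels_onehot, test_size=0.3, random_state=42)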
import tensorflow as tf

X = tf.placeholder(tf.float32, [None, 54])
y_true = tf.placeholder(tf.float32, [None, 7])

def mlp(x):
    # hidden layer
    w1 = tf.Variable(tf.random_uniform([54, 18]))
    b1 = tf.Variable(tf.zeros([18]))
    h1 = tf.nn.relu(tf.matmul(x, w1) + b1)
    # output layer
    w3 = tf.Variable(tf.random_uniform([18, 7]))
    b3 = tf.Variable(tf.zeros([7]))
    logits = tf.matmul(h1, w3) + b3
    return logits

logits = mlp(X)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=y_true))
training_op = tf.train.AdamOptimizer(learning_rate=0.01).minimize(cost)

# accuracy ops, defined once so new nodes are not added to the graph every epoch
preds = tf.nn.softmax(logits)  # apply softmax to logits
correct_prediction = tf.equal(tf.argmax(preds, 1), tf.argmax(y_true, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))

# initialize
init = tf.global_variables_initializer()

# training hyperparameters
epoch_cnt = 30
batch_size = 1000
iteration = len(X_train) // batch_size
# Start training
with tf.Session() as sess:
    # Run the initializer
    sess.run(init)
    for epoch in range(epoch_cnt):
        avg_loss = 0.
        start = 0; end = batch_size
        for i in range(iteration):
            _, loss = sess.run([training_op, cost],
                               feed_dict={X: X_train[start:end], y_true: y_train[start:end]})
            start += batch_size; end += batch_size
            # Compute average loss over the epoch
            avg_loss += loss / iteration
        # Validate model at the end of each epoch
        cur_val_acc = accuracy.eval({X: X_test, y_true: y_test})
        # cur_val_acc = accuracy.eval({X: X_train, y_true: y_train})  # training accuracy instead
        print("epoch: " + str(epoch) + ", validation accuracy: " + str(cur_val_acc) + ', loss: ' + str(avg_loss))

    # Test model
    print("[Test Accuracy] :", accuracy.eval({X: X_test, y_true: y_test}))
My results are as follows:
epoch: 0, validation accuracy: 0.03530512, loss: 1143.4326061579973
epoch: 1, validation accuracy: 0.48712343, loss: 173.77011137162083
epoch: 2, validation accuracy: 0.48746192, loss: 1.329379230957901
epoch: 3, validation accuracy: 0.48755372, loss: 1.2196576229016765
epoch: 4, validation accuracy: 0.4875709, loss: 1.2079436527978022
epoch: 5, validation accuracy: 0.4875652, loss: 1.2053790902122492
epoch: 6, validation accuracy: 0.48757666, loss: 1.2046689584924664
epoch: 7, validation accuracy: 0.48757666, loss: 1.2044714586855156
epoch: 8, validation accuracy: 0.4875824, loss: 1.2044040254991633
epoch: 9, validation accuracy: 0.4875824, loss: 1.2043715470603535
epoch: 10, validation accuracy: 0.4875824, loss: 1.2043672442216014
epoch: 11, validation accuracy: 0.4875824, loss: 1.204392935732021
epoch: 12, validation accuracy: 0.4875824, loss: 1.2043428919805674
epoch: 13, validation accuracy: 0.4875824, loss: 1.2043236044093304
epoch: 14, validation accuracy: 0.4875824, loss: 1.204309070235108
epoch: 15, validation accuracy: 0.4875824, loss: 1.2043108996279133
epoch: 16, validation accuracy: 0.4875824, loss: 1.2043230445881186
epoch: 17, validation accuracy: 0.4875824, loss: 1.2043319362504727
epoch: 18, validation accuracy: 0.4875824, loss: 1.2043379787814439
epoch: 19, validation accuracy: 0.4875824, loss: 1.2043409186777816
epoch: 20, validation accuracy: 0.4875824, loss: 1.2043399413304368
epoch: 21, validation accuracy: 0.4875824, loss: 1.2043371798810107
epoch: 22, validation accuracy: 0.4875824, loss: 1.204332315752953
epoch: 23, validation accuracy: 0.4875824, loss: 1.2043318411179365
epoch: 24, validation accuracy: 0.4875824, loss: 1.2043327070368917
epoch: 25, validation accuracy: 0.4875824, loss: 1.2043332901641064
epoch: 26, validation accuracy: 0.4875824, loss: 1.2043333211776064
epoch: 27, validation accuracy: 0.4875824, loss: 1.2043338380570483
epoch: 28, validation accuracy: 0.4875824, loss: 1.2043338463517836
epoch: 29, validation accuracy: 0.4875824, loss: 1.2043339615972168
[Test Accuracy] : 0.4875824

The accuracy gets stuck at 0.4875824 from epoch 8 onward, which is almost exactly the share of the majority class in Covertype (class 2, Lodgepole Pine, roughly 48.8%), so the network appears to be predicting a single class for every sample.
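That baseline can be checked directly from the one-hot test labels (a minimal sketch, assuming y_test is the array whose shape is shown above):

# count test samples per class from the one-hot labels,
# then report the fraction held by the most frequent class
class_counts = y_test.sum(axis=0)        # shape (7,)
print(class_counts.max() / len(y_test))  # ~0.4876 here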