使用Tensorflow进行二进制分类

时间:2017-07-10 01:09:23

标签: python tensorflow logistic-regression

我正在尝试使用tensorflow执行二进制分类,这是为cs20si分配的。这些都非常简单,但我正在学习如何从头开始编写tensorflow来学习复杂的细节,例如设置数据管道,维护检查点等等。我有训练和测试的代码,并且无法达到超过12%的准确率,而sklearn使用相同的模型获得78%的准确率。我理解问题必须在我对tensorflow的代码中。数据来自here,我可以看到我工作的jupyter笔记本here。我已经发布了变量设置,培训和测试代码。我无法找到为什么损失总是在4000s。

VARIABLE SETUP

import glob, os
for f in glob.glob("/tmp/model.ckpt*"):
    os.remove(f)

saver = tf.train.Saver([w,b])
EPOCHS = 1000

with tf.Session() as sess:
    # Step 7: initialize the necessary variables, in this case, w and b
    sess.run(tf.global_variables_initializer())

    # Step 8: train the model
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    n_batches = int(n_train_data/BATCH_SIZE)
    for epoch in tqdm(range(EPOCHS)): # run epochs
        avg_loss = 0

        for _ in range(n_batches):
            x_batch, y_batch = sess.run([data1_feature_batch, data1_label_batch])
            # Session runs train_op to minimize loss
            feed_dict={X: x_batch, Y:y_batch}
            _, loss_batch = sess.run([optimizer, loss], feed_dict=feed_dict)
            avg_loss += loss_batch/n_batches

        if (epoch+1) % 100 == 0:
            print "avg_loss",avg_loss

    coord.request_stop()
    coord.join(threads)

    # Step 9: saving the values of w and b
    print "weights",w.eval()
    print "bias",b.eval()

    # Add ops to save and restore all the variables.
    save_path = saver.save(sess, "/tmp/logit_reg_tf_model.ckpt")

TRAINING

# Step 10: predict
# test the model

saver = tf.train.import_meta_graph("/tmp/logit_reg_tf_model.ckpt.meta")
with tf.Session() as sess:
    # nitialize the necessary variables, in this case, w and b
    sess.run(tf.global_variables_initializer())
    # Add ops to save and restore all the variables.
    saver.restore(sess, "/tmp/logit_reg_tf_model.ckpt")
    print "weights",w.eval()
    print "bias",b.eval()

    total_correct_preds = 0
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)

    try:
        for i in range(20):
            x_batch, y_batch = sess.run([test_data1_feature_batch, test_data1_label_batch])
            total_correct_preds += sess.run(accuracy, feed_dict={X: x_batch, Y:y_batch})

    except tf.errors.OutOfRangeError:
        print('Done testing ...')
    coord.request_stop()
    coord.join(threads)

    print 'Accuracy {0}'.format(total_correct_preds/n_test_data)

测试

#!/usr/bin/env python

import asyncio
import websockets

async def hello(websocket, path):
    name = await websocket.recv()
    print("< {}".format(name))

    greeting = "Hello {}!".format(name)
    await websocket.send(greeting)
    print("> {}".format(greeting))

start_server = websockets.serve(hello, 'localhost', 8765)

asyncio.get_event_loop().run_until_complete(start_server)
asyncio.get_event_loop().run_forever()

1 个答案:

答案 0 :(得分:0)

标准化您的输入,您可以使用sklearn的[lein-figwheel "0.5.11"]。学习率很大,减少说0.01并尝试。权重正则化也非常大,删除它并在以后添加(如果需要)。