我正在使用 CNN 进行背景减除（background subtraction），但模型没有正确学习

时间:2019-02-28 08:36:26

标签: tensorflow machine-learning deep-learning computer-vision google-colaboratory

My test accuracy stays constant, while my training accuracy keeps going up and down.

我每次向模型传入两张图像进行训练：一张包含前景，另一张不包含前景。

# Train the background-subtraction CNN. Periodically report accuracy on one
# fixed test batch and on the current training batch, and checkpoint weights.
#
# The model, optimiser op `train`, `saver`, `init`, the placeholders
# (input_image, gt, learning_rate, is_train, batch_size), the input pipelines
# (train_batch, test_batch), build_img_pair, max_iteration, train_batch_size
# and test_batch_size are all defined elsewhere in the script.

# Accuracy op, built ONCE before the loop. Creating these ops inside the loop
# added new nodes to the graph on every report, so the graph grew without
# bound and every report became slower.
# NOTE(review): argmax over axis 1 assumes the class scores live on axis 1.
# If convo_final/gt are NHWC per-pixel maps, the class axis is -1 and axis 1
# is the image height -- confirm against the model definition.
matches = tf.equal(tf.argmax(convo_final, 1), tf.argmax(gt, 1))
acc = tf.reduce_mean(tf.cast(matches, tf.float32))

with tf.Session() as sess:
    sess.run(init)
    # Start the queue runners that produce train_batch / test_batch.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # A single test batch is fetched up front and reused for every report.
        input_test, outputGT_test = build_img_pair(sess.run(test_batch))

        for i in range(max_iteration):
            input_train, outputGT_train = build_img_pair(sess.run(train_batch))

            # Piecewise-constant learning-rate schedule.
            if i < 500:
                lr = 0.001
            elif i < 1500:
                lr = 0.0005
            else:
                lr = 0.0001

            # The optimisation step must run with is_train=True; the previous
            # code fed False here. is_train presumably gates training-only
            # behaviour such as dropout / batch norm -- confirm in the model.
            # If it gates batch norm, also make sure the UPDATE_OPS collection
            # is a dependency of `train`, or the eval-mode statistics will
            # never be updated.
            train.run({input_image: input_train, gt: outputGT_train,
                       learning_rate: lr, is_train: True,
                       batch_size: train_batch_size})

            # Report accuracy every 10 steps (evaluation runs with
            # is_train=False).
            if i % 10 == 0:
                print('Currently on step {}'.format(i))
                print('Test Accuracy is:', end = "     ")
                print(sess.run(acc, {input_image: input_test, gt: outputGT_test,
                                     learning_rate: 0.001, is_train: False,
                                     batch_size: test_batch_size}),
                      end = "     ")

                print('train Accuracy is:', end = "     ")
                print(sess.run(acc, {input_image: input_train, gt: outputGT_train,
                                     learning_rate: 0.001, is_train: False,
                                     batch_size: train_batch_size}))
                print('\n')

            # Periodic checkpoint every 500 steps (including step 0).
            if i % 500 == 0:
                save_path = saver.save(sess, "data/model.ckpt")
                print("Model saved in path: %s" % save_path)

        # Final checkpoint: the periodic save above can miss up to 499
        # trailing training steps.
        save_path = saver.save(sess, "data/model.ckpt")
        print("Model saved in path: %s" % save_path)
    finally:
        # Shut the queue-runner threads down even if training raises.
        coord.request_stop()
        coord.join(threads)

    tf.train.write_graph(sess.graph.as_graph_def(), 'graph/', 'Model.pb', as_text=True)

0 个答案:

没有答案