移动网络总是预测相同的标签

时间:2018-08-01 05:22:38

标签: python tensorflow deep-learning conv-neural-network

我训练了mobilenet_v2模型对8类图像数据集进行分类。我在数据集中使用google的pretrain参数进行微调和再训练,训练过程中准确度快速提高,训练结束时达到98%。 但是当我使用经过训练的模型来预测测试图像时,它总是输出相同的标签。 我也尝试过vgg16。效果很好

这是我的代码

推断

def inference(inputs,is_training=True):
with tf.contrib.slim.arg_scope(mobilenet_v2.training_scope(is_training=is_training)):
    net, endpoints = mobilenet_v2.mobilenet(input_tensor = inputs,num_classes = n_class,conv_defs = V2_18_DEF)
print('mobilenet output',net.get_shape().as_list())
return net

损失

def loss(logit,label):
    losses=[]
    with tf.name_scope('LOSS'):
        class_loss = tf.nn.softmax_cross_entropy_with_logits(labels=label,logits=logit)
        class_loss = tf.reduce_mean(class_loss,axis = 0)
        tf.summary.scalar('class_loss',class_loss) 
    regularization_loss = tf.reduce_sum(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
    total_loss = class_loss + regularization_loss

train_step

def train_op(loss):
        train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)#
        return train_step

测试

def test():

    inputs = tf.placeholder(name='inputs',shape=[None,224,224,3],dtype=tf.float32)

    label_holder = tf.placeholder(name='label_holder',shape=[None,n_class],dtype=tf.float32)

    if(net_type=='mobile_v2'): 
        test_logits = inference(inputs,is_training=False)
    elif(net_type=='mobile_v1'):
        test_logits = inference_mobile_v1(inputs,is_training=False)
    elif net_type=='vgg16':
        test_logits = inference_vgg(inputs,is_training=False)
    predict = tf.nn.softmax(test_logits)
    predict_result = tf.argmax(predict,axis=1,output_type=tf.int32)
    true_result = tf.argmax(label_holder,axis=1,output_type=tf.int32)
    correct_predict = tf.equal(tf.argmax(predict,axis=1,output_type=tf.int32),tf.argmax(label_holder,axis=1,\
             output_type = tf.int32))
    accuracy = tf.reduce_mean(tf.cast(correct_predict,tf.float32))
    pos = 0
    acc = 0
    count = 0
    sess = tf.Session()

    ckpt_path=''
    try:
        ckpt_path = sys.argv[3]
    except:
        ckpt = tf.train.get_checkpoint_state(output_path)
        if(ckpt and ckpt.model_checkpoint_path):
            ckpt_path = ckpt.model_checkpoint_path

    sess.run(tf.global_variables_initializer())
    variables_to_restore = slim.get_variables_to_restore()
    for var in variables_to_restore:
        print(var.name)
    saver = tf.train.Saver(variables_to_restore)
    saver.restore(sess,ckpt_path)


    dataset = Dataset(classes,0.8)
    test_data,test_label = dataset.get_test_data()
    while(pos<len(test_data)):
        start = pos
        end = min(pos+batch_size,len(test_data))
        batch_img = test_data[start:end]
        batch_label = test_label[start:end]
        batch_data={}
        batch_data[inputs] = batch_img
        batch_data[label_holder] = batch_label
        batch_acc,p_result,t_result = sess.run([accuracy,predict_result,true_result],feed_dict=batch_data)
        print('batch_acc',batch_acc)
        print(p_result)
        print(t_result)
        acc += batch_acc
        pos = end
        count+=1
    acc = acc/count
    print('test acc',acc)
    sess.close()

火车功能

def train():

    inputs = tf.placeholder(name='inputs',shape=[None,224,224,3],dtype=tf.float32)#
    labels_placeholder = {}
    label_holder = tf.placeholder(name='label_holder',shape=[None,n_class],dtype=tf.float32) #
    if(net_type=='mobile_v2'):
        train_logits = inference(inputs)
    elif(net_type=='mobile_v1'):
        train_logits = inference_mobile_v1(inputs,is_training=True)
    elif net_type=='vgg16':
        train_logits = inference_vgg(inputs,is_training=True)
    loss_op = loss(train_logits,label_holder)
    predict = tf.nn.softmax(train_logits)
    print('predict shape',predict.get_shape().as_list())
    predict_result = tf.argmax(predict,axis=1,output_type = tf.int32)
    correct_predict = tf.equal(tf.argmax(predict,axis=1,output_type=tf.int32),tf.argmax(label_holder,axis=1,\
             output_type = tf.int32))
    accuracy = tf.reduce_mean(tf.cast(correct_predict,tf.float32))

    for var in tf.trainable_variables():
        tf.summary.histogram(var.name,var)

    train_step = train_op(loss_op)

    #pretrain_restore
    all_variable = tf.trainable_variables()
    pretrain_vals=[]
    reader = pywrap_tensorflow.NewCheckpointReader(pretrain_model_path)
    var_to_shape_map=reader.get_variable_to_shape_map()
    for var in all_variable:
        print(var.name)
        if('Logits' in var.name):
            continue
        if(var.name.split(':')[0] in var_to_shape_map):
            print('restore',var.name)
            pretrain_vals.append(var)
    pretrain_saver = tf.train.Saver(pretrain_vals)

    variables_to_restore = slim.get_variables_to_restore()
    train_saver = tf.train.Saver(variables_to_restore)
    #tensorboard
    sess = tf.Session()
    merge_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter('./tensorboard/train',sess.graph)

    #init
    print('init & restore')
    sess.run(tf.global_variables_initializer())
    pretrain_saver.restore(sess,pretrain_model_path)

    dataset = Dataset(classes,0.8)#
    data_amount = dataset.data_amount 
    each_epoch = int(data_amount/batch_size)
    total_step = int(each_epoch*EPOCH)
    global_step = 0
    test_data,test_label = dataset.next_batch(80)
    for epoch in range(EPOCH):
        pos = 0
        count = 0
        acc = 0
        while(pos<len(test_data)):
            start = pos
            end = min(pos+batch_size,len(test_data))
            batch_img = test_data[start:end]
            batch_label = test_label[start:end]
            batch_data={}
            batch_data[inputs] = batch_img
            batch_data[label_holder] = batch_label
            batch_acc,t_result ,loss_val= sess.run([accuracy,predict_result,loss_op],feed_dict=batch_data)
            print('batch_acc',batch_acc,loss_val)
            print(t_result)
            acc += batch_acc
            pos = end
            count+=1
        acc = acc/count
        print('=====test_acc===',acc)
        for epoch_step in range(each_epoch):

            batch_img,batch_label = dataset.next_batch(batch_size)#
            batch_data={}
            batch_data[inputs] = batch_img
            batch_data[label_holder] = batch_label
            merge_str,loss_val,acc,p_result= sess.run([merge_op,loss_op,accuracy,predict_result],feed_dict=batch_data)
            sess.run(train_step,feed_dict = batch_data)
            print('loss %f,acc %f, global step %d ,epoch %d,epoch_step %d' % (loss_val,acc,global_step,epoch,epoch_step) )
            print(p_result)
            summary_writer.add_summary(merge_str,global_step=global_step)
            summary_writer.flush()
            global_step+=1

        save_path = os.path.join(output_path,'model.ckpt')
        train_saver.save(sess,save_path,global_step)

enter image description here

0 个答案:

没有答案