I trained a mobilenet_v2 model to classify an 8-class image dataset, fine-tuning from Google's pretrained checkpoint. During training the accuracy rises quickly and reaches 98% by the end. But when I use the trained model to predict test images, it always outputs the same label. I also tried vgg16 with the same pipeline, and it works fine.
Here is my code:
import os
import sys

import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.python import pywrap_tensorflow
from nets.mobilenet import mobilenet_v2  # from the tensorflow/models slim repo

# n_class, net_type, batch_size, EPOCH, classes, output_path, pretrain_model_path,
# V2_18_DEF, Dataset, inference_mobile_v1 and inference_vgg are defined elsewhere
# in my project.

def inference(inputs, is_training=True):
    with slim.arg_scope(mobilenet_v2.training_scope(is_training=is_training)):
        net, endpoints = mobilenet_v2.mobilenet(input_tensor=inputs,
                                                num_classes=n_class,
                                                conv_defs=V2_18_DEF)
    print('mobilenet output', net.get_shape().as_list())
    return net
def loss(logit, label):
    with tf.name_scope('LOSS'):
        class_loss = tf.nn.softmax_cross_entropy_with_logits(labels=label, logits=logit)
        class_loss = tf.reduce_mean(class_loss)
        tf.summary.scalar('class_loss', class_loss)
        regularization_loss = tf.add_n(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))
        total_loss = class_loss + regularization_loss
    return total_loss
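
# For reference, an equivalent TF 1.x formulation via tf.losses, which adds the
# cross-entropy to the LOSSES collection and folds in REGULARIZATION_LOSSES for
# you; just a sketch, assuming the same one-hot labels and logits as above:
def loss_via_tf_losses(logit, label):
    tf.losses.softmax_cross_entropy(onehot_labels=label, logits=logit)
    return tf.losses.get_total_loss()  # cross-entropy + all regularization losses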

def train_op(loss):
    train_step = tf.train.AdamOptimizer(1e-4).minimize(loss)
    return train_step
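
# MobileNet V2 is batch-norm heavy, and in TF 1.x slim the moving mean/variance
# updates live in tf.GraphKeys.UPDATE_OPS; minimize() alone does not run them, so
# inference with is_training=False can see stale statistics and collapse to a
# single label. A sketch of a variant that adds the dependency (whether that is
# the cause here is only an assumption):
def train_op_with_bn_updates(loss):
    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        return tf.train.AdamOptimizer(1e-4).minimize(loss)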

def test():
    inputs = tf.placeholder(name='inputs', shape=[None, 224, 224, 3], dtype=tf.float32)
    label_holder = tf.placeholder(name='label_holder', shape=[None, n_class], dtype=tf.float32)
    if net_type == 'mobile_v2':
        test_logits = inference(inputs, is_training=False)
    elif net_type == 'mobile_v1':
        test_logits = inference_mobile_v1(inputs, is_training=False)
    elif net_type == 'vgg16':
        test_logits = inference_vgg(inputs, is_training=False)
    predict = tf.nn.softmax(test_logits)
    predict_result = tf.argmax(predict, axis=1, output_type=tf.int32)
    true_result = tf.argmax(label_holder, axis=1, output_type=tf.int32)
    correct_predict = tf.equal(predict_result, true_result)
    accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))

    sess = tf.Session()
    # checkpoint path comes from the command line, else the latest one in output_path
    ckpt_path = ''
    try:
        ckpt_path = sys.argv[3]
    except IndexError:
        ckpt = tf.train.get_checkpoint_state(output_path)
        if ckpt and ckpt.model_checkpoint_path:
            ckpt_path = ckpt.model_checkpoint_path
    sess.run(tf.global_variables_initializer())
    variables_to_restore = slim.get_variables_to_restore()
    for var in variables_to_restore:
        print(var.name)
    saver = tf.train.Saver(variables_to_restore)
    saver.restore(sess, ckpt_path)

    dataset = Dataset(classes, 0.8)
    test_data, test_label = dataset.get_test_data()
    pos = 0
    acc = 0
    count = 0
    while pos < len(test_data):
        start = pos
        end = min(pos + batch_size, len(test_data))
        batch_data = {inputs: test_data[start:end],
                      label_holder: test_label[start:end]}
        batch_acc, p_result, t_result = sess.run([accuracy, predict_result, true_result],
                                                 feed_dict=batch_data)
        print('batch_acc', batch_acc)
        print(p_result)
        print(t_result)
        acc += batch_acc
        pos = end
        count += 1
    acc = acc / count
    print('test acc', acc)
    sess.close()
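
# One way to probe whether stale batch-norm statistics are the culprit: read the
# moving variances straight from the checkpoint; values still pinned at their 1.0
# initializer suggest the update ops never ran. A sketch, assuming ckpt_path
# points at the trained checkpoint:
def inspect_bn_stats(ckpt_path):
    reader = pywrap_tensorflow.NewCheckpointReader(ckpt_path)
    for name in reader.get_variable_to_shape_map():
        if 'moving_variance' in name:
            print(name, reader.get_tensor(name).flatten()[:5])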

def train():
    inputs = tf.placeholder(name='inputs', shape=[None, 224, 224, 3], dtype=tf.float32)
    label_holder = tf.placeholder(name='label_holder', shape=[None, n_class], dtype=tf.float32)
    if net_type == 'mobile_v2':
        train_logits = inference(inputs, is_training=True)
    elif net_type == 'mobile_v1':
        train_logits = inference_mobile_v1(inputs, is_training=True)
    elif net_type == 'vgg16':
        train_logits = inference_vgg(inputs, is_training=True)
    loss_op = loss(train_logits, label_holder)
    predict = tf.nn.softmax(train_logits)
    print('predict shape', predict.get_shape().as_list())
    predict_result = tf.argmax(predict, axis=1, output_type=tf.int32)
    correct_predict = tf.equal(predict_result,
                               tf.argmax(label_holder, axis=1, output_type=tf.int32))
    accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))
    for var in tf.trainable_variables():
        tf.summary.histogram(var.name, var)
    train_step = train_op(loss_op)

    # restore the pretrained weights, skipping the Logits layer (different class
    # count) and anything not present in the checkpoint
    pretrain_vals = []
    reader = pywrap_tensorflow.NewCheckpointReader(pretrain_model_path)
    var_to_shape_map = reader.get_variable_to_shape_map()
    for var in tf.trainable_variables():
        print(var.name)
        if 'Logits' in var.name:
            continue
        if var.name.split(':')[0] in var_to_shape_map:
            print('restore', var.name)
            pretrain_vals.append(var)
    pretrain_saver = tf.train.Saver(pretrain_vals)
    train_saver = tf.train.Saver(slim.get_variables_to_restore())

    # tensorboard
    sess = tf.Session()
    merge_op = tf.summary.merge_all()
    summary_writer = tf.summary.FileWriter('./tensorboard/train', sess.graph)

    # init & restore
    print('init & restore')
    sess.run(tf.global_variables_initializer())
    pretrain_saver.restore(sess, pretrain_model_path)

    dataset = Dataset(classes, 0.8)
    each_epoch = int(dataset.data_amount / batch_size)
    global_step = 0
    # one batch held out as a quick validation set between epochs
    test_data, test_label = dataset.next_batch(80)
    for epoch in range(EPOCH):
        # evaluate on the held-out batch before each epoch
        pos = 0
        count = 0
        acc = 0
        while pos < len(test_data):
            start = pos
            end = min(pos + batch_size, len(test_data))
            batch_data = {inputs: test_data[start:end],
                          label_holder: test_label[start:end]}
            batch_acc, p_result, loss_val = sess.run([accuracy, predict_result, loss_op],
                                                     feed_dict=batch_data)
            print('batch_acc', batch_acc, loss_val)
            print(p_result)
            acc += batch_acc
            pos = end
            count += 1
        acc = acc / count
        print('=====test_acc===', acc)
        # one epoch of training; run the train step and the metrics in a single
        # sess.run so each batch does one forward pass
        for epoch_step in range(each_epoch):
            batch_img, batch_label = dataset.next_batch(batch_size)
            batch_data = {inputs: batch_img, label_holder: batch_label}
            _, merge_str, loss_val, acc_val, p_result = sess.run(
                [train_step, merge_op, loss_op, accuracy, predict_result],
                feed_dict=batch_data)
            print('loss %f, acc %f, global step %d, epoch %d, epoch_step %d'
                  % (loss_val, acc_val, global_step, epoch, epoch_step))
            print(p_result)
            summary_writer.add_summary(merge_str, global_step=global_step)
            summary_writer.flush()
            global_step += 1
        # save a checkpoint at the end of every epoch
        train_saver.save(sess, os.path.join(output_path, 'model.ckpt'), global_step)
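
# For completeness, a minimal entry point; the argv layout here is illustrative,
# not the exact script interface, but it is consistent with test() reading the
# checkpoint path from sys.argv[3]:
if __name__ == '__main__':
    # e.g. python finetune.py train mobile_v2  /  python finetune.py test mobile_v2 <ckpt>
    net_type = sys.argv[2] if len(sys.argv) > 2 else 'mobile_v2'
    if sys.argv[1] == 'train':
        train()
    else:
        test()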