I used tf.train.shuffle_batch during training and no placeholders at all. When I freeze the ckpt into a pb file, I can feed data into the shuffle_batch tensor, but there is no proper input tensor, and the feed must be exactly the same size as the shuffle_batch output. How can I solve this? I know I could rebuild the network, restore the parameters, and then freeze, but isn't that unwise?
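For reference, a minimal sketch of the rebuild-and-restore idea mentioned above, assuming the inference function and savepath from the training script below; the placeholder name input_img is a hypothetical choice:

import tensorflow as tf

# Rebuild the same network on top of a placeholder instead of shuffle_batch.
# Shapes follow the training pipeline: crops of [480, 224, 3].
inputs = tf.placeholder(tf.float32, [None, 480, 224, 3], name='input_img')
_, _, out_vol = inference(inputs, False, None)

with tf.Session() as sess:
    saver = tf.train.Saver()
    # Restore the trained weights into the placeholder-based graph,
    # then freeze it with input_img as a real input node.
    saver.restore(sess, tf.train.latest_checkpoint(savepath))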
Training code
import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
import os
trainsize=35680
testsize=889
batch_size=32
inputW=224
inputH=480
TRAIN_TFRECORD='./train.tfrecords'
TEST_TFRECORD='./test.tfrecords'
BATCH_CAPACITY=512
MIN_AFTER_DEQU=256
MAX_Cycle=100000
TRAIN_CYCLE=int(trainsize/batch_size)
TEST_CYCLE=int(testsize/batch_size)
learning_rt = 0.001
savepath='./ckpt/'
logpath='./logs/'
def network(inputs, is_train, reuse):
    BITW = 8
    BITA = 8
    Decay = 0.99
    Epsi = 1e-5
    with tf.variable_scope('Model', reuse=reuse):
        net = InputLayer(inputs, name='input')  # 224*480
        net = QuanConv2dWithBN(net, 32, (3, 3), (1, 1), 'SAME', tf.nn.relu, decay=Decay, epsilon=Epsi, is_train=is_train, bitW=BITW, bitA=BITA, name='Conv1_1')
        net = QuanConv2dWithBN(net, 64, (3, 3), (2, 2), 'SAME', tf.nn.relu, decay=Decay, epsilon=Epsi, is_train=is_train, bitW=BITW, bitA=BITA, name='Conv1_2')  # 112*240
        net = QuanConv2dWithBN(net, 64, (3, 3), (1, 1), 'SAME', tf.nn.relu, decay=Decay, epsilon=Epsi, is_train=is_train, bitW=BITW, bitA=BITA, name='Conv2_1')
        net = QuanConv2dWithBN(net, 128, (3, 3), (2, 2), 'SAME', tf.nn.relu, decay=Decay, epsilon=Epsi, is_train=is_train, bitW=BITW, bitA=BITA, name='Conv2_2')  # 56*120
        net = QuanConv2dWithBN(net, 128, (3, 3), (1, 1), 'SAME', tf.nn.relu, decay=Decay, epsilon=Epsi, is_train=is_train, bitW=BITW, bitA=BITA, name='Conv3_1')
        net = QuanConv2dWithBN(net, 64, (1, 1), (1, 1), 'SAME', tf.nn.relu, decay=Decay, epsilon=Epsi, is_train=is_train, bitW=BITW, bitA=BITA, name='Conv3_2')
        net = QuanConv2dWithBN(net, 128, (3, 3), (2, 2), 'SAME', tf.nn.relu, decay=Decay, epsilon=Epsi, is_train=is_train, bitW=BITW, bitA=BITA, name='Conv3_3')  # 28*60
        net = QuanConv2dWithBN(net, 64, (3, 3), (1, 1), 'SAME', tf.nn.relu, decay=Decay, epsilon=Epsi, is_train=is_train, bitW=BITW, bitA=BITA, name='Conv4_1')
        net = QuanConv2dWithBN(net, 96, (3, 3), (2, 2), 'VALID', tf.nn.relu, decay=Decay, epsilon=Epsi, is_train=is_train, bitW=BITW, bitA=BITA, name='Conv4_2')  # 14*30
        print(net.outputs)
        net = QuanConv2dWithBN(net, 128, (3, 3), (2, 2), 'SAME', tf.nn.relu, decay=Decay, epsilon=Epsi, is_train=is_train, bitW=BITW, bitA=BITA, name='Conv5_1')  # 7*30
        net = QuanConv2dWithBN(net, 128, (3, 3), (1, 2), 'VALID', tf.nn.relu, decay=Decay, epsilon=Epsi, is_train=is_train, bitW=BITW, bitA=BITA, name='Conv5_2')  # 3*30
        net = QuanConv2d(net, 128, (3, 3), (1, 2), 'VALID', tf.nn.leaky_relu, bitW=BITW, bitA=BITA, name='Conv5_3')  # 1*30
        print(net.outputs)
        net = FlattenLayer(net, name='flat1')
        net = QuanDenseLayer(net, 128, act=tf.nn.leaky_relu, bitW=BITW, bitA=BITA, name='dense1')
        net = DropoutLayer(net, 0.5, is_fix=True, is_train=is_train, name='drop1')
        net = DenseLayer(net, 1, name='dense2')
        outnet = net
        volcume = net.outputs
        print(volcume)
    return outnet, net.outputs, volcume
def inference(inputs, is_train, reuse):
    return network(inputs, is_train, reuse)

def read_and_decode(filename):
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'img': tf.FixedLenFeature([], tf.string),
                                           'num': tf.FixedLenFeature([], tf.float32),
                                       })
    img = tf.decode_raw(features['img'], tf.uint8)
    img = tf.reshape(img, [480, 240, 3])
    img = tf.random_crop(img, [480, 224, 3])
    img = tf.image.random_brightness(img, max_delta=0.3)
    img = tf.image.random_contrast(img, lower=0.1, upper=0.5)
    # img = tf.image.random_hue(img, max_delta=0.1)
    # img = tf.image.random_saturation(img, lower=0, upper=2.5)
    img = tf.image.per_image_standardization(img)
    label = tf.reshape(tf.cast(features['num'], tf.float32) * (1. / 230.) - 0.5, [1])
    return img, label
def read_and_decode_test(filename):
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'img': tf.FixedLenFeature([], tf.string),
                                           'num': tf.FixedLenFeature([], tf.float32),
                                       })
    img = tf.decode_raw(features['img'], tf.uint8)
    img = tf.reshape(img, [480, 240, 3])
    img = img[:, 8:232, :]
    img = tf.image.per_image_standardization(img)
    label = tf.reshape(tf.cast(features['num'], tf.float32) * (1. / 230.) - 0.5, [1])
    return img, label

def smooth_L1(x):
    return tf.where(tf.less_equal(tf.abs(x), 1.0), tf.multiply(0.5, tf.pow(x, 2.0)), tf.subtract(tf.abs(x), 0.5))

def cal_loss(logits, labels):
    # return tf.clip_by_value(tf.reduce_mean(tf.losses.mean_squared_error(labels, logits)), 0.000001, 10000000.)
    return tf.reduce_mean(tf.where(tf.less_equal(tf.abs(logits - labels), 0.02), 0.00001 * tf.ones_like(logits - labels), tf.multiply(1., tf.pow(logits - labels, 2.0))))
    # return tf.clip_by_value(tf.reduce_sum(smooth_L1(labels - logits)), 0.0000001, 100.)

def cal_acc(logits, labels):
    return tf.reduce_mean(tf.cast(tf.less_equal(tf.abs(labels - logits), tf.ones_like(labels) * .1), tf.float32))
if __name__ == '__main__':
    img_train, num_train = read_and_decode(TRAIN_TFRECORD)
    # the test set goes through the deterministic test decoder
    img_test, num_test = read_and_decode_test(TEST_TFRECORD)
    img_train_batch, num_train_batch = tf.train.shuffle_batch(
        [img_train, num_train], batch_size=batch_size, capacity=BATCH_CAPACITY,
        min_after_dequeue=MIN_AFTER_DEQU)
    img_test_batch, num_test_batch = tf.train.batch(
        [img_test, num_test], batch_size=batch_size)
    net, _, logits_train = inference(img_train_batch, True, None)
    _, _, logits_test = inference(img_test_batch, False, True)
    loss_train = cal_loss(logits_train, num_train_batch)
    loss_test = cal_loss(logits_test, num_test_batch)
    acc_test = cal_acc(logits_test, num_test_batch)
    acc_train = cal_acc(logits_train, num_train_batch)
    global_step = tf.train.create_global_step()
    # tf.train.get_global_step()
    learning_rate = tf.train.exponential_decay(learning_rt, global_step,
                                               5000, 0.9, staircase=True)
    train = tf.train.MomentumOptimizer(learning_rate, momentum=0.9).minimize(loss_train, global_step=global_step)
    # train = tf.train.AdamOptimizer(learning_rt).minimize(loss_train)
    tf.summary.scalar('loss_train', loss_train)
    tf.summary.scalar('acc_train', acc_train)
    merged = tf.summary.merge_all()
    with tf.Session(config=tf.ConfigProto()) as sess:
        trainwrite = tf.summary.FileWriter(logpath, sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        run_cycle = 0
        # restore the most recent checkpoint in savepath, if any
        ckpt = tf.train.latest_checkpoint(savepath)
        if ckpt is not None:
            print('\nStart Restore')
            saver.restore(sess, ckpt)
            print('\nEnd Restore')
        print('\nStart Training')
        try:
            while not coord.should_stop():
                while run_cycle < MAX_Cycle:
                    run_cycle += 1
                    # if run_cycle % 10 == 0:
                    #     learning_rt *= 0.6
                    # if run_cycle % 200 == 0:
                    #     learning_rt *= 2.
                    l_tall = 0
                    a_tall = 0
                    l_teall = 0
                    a_teall = 0
                    for train_c in range(TRAIN_CYCLE):
                        _, l_train, a_train = sess.run([train, loss_train, acc_train])
                        l_tall += l_train
                        a_tall += a_train
                        if (train_c + 1) % 100 == 0:
                            print('train_loss:%f' % (l_tall / 100.))
                            print('train_acc:%f' % (a_tall / 100.))
                            l_tall = 0
                            a_tall = 0
                        if (train_c + 1) % 500 == 0:
                            print('Global Step:', sess.run(global_step))
                            result_merged = sess.run(merged)
                            trainwrite.add_summary(result_merged, run_cycle * TRAIN_CYCLE + train_c)
                    for test_c in range(TEST_CYCLE):
                        l_test, a_test = sess.run([loss_test, acc_test])
                        l_teall += l_test
                        a_teall += a_test
                        if (test_c + 1) % TEST_CYCLE == 0:
                            print('------------------')
                            print('test_loss:%f' % (l_teall / TEST_CYCLE))
                            print('test_acc:%f' % (a_teall / TEST_CYCLE))
                            print('------------------')
                            l_teall = 0
                            a_teall = 0
                    saver.save(sess, savepath + str(run_cycle) + '.ckpt')
        except tf.errors.OutOfRangeError:
            print('Done training!!!')
        finally:
            # When done, ask the threads to stop.
            coord.request_stop()
        coord.join(threads)
        sess.close()
Freeze code

import os, argparse
import tensorflow as tf
from tensorflow.python.framework import graph_util
dir = os.path.dirname(os.path.realpath(__file__))
def freeze_graph(model_folder, output_nodes='y_hat',
                 output_filename='frozen-graph.pb',
                 rename_outputs=None):
    # Load checkpoint
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path
    output_graph = output_filename
    # Devices should be cleared to allow TensorFlow to control placement of
    # the graph when loading on different machines
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)
    graph = tf.get_default_graph()
    onames = output_nodes.split(',')
    # https://stackoverflow.com/a/34399966/4190475
    if rename_outputs is not None:
        nnames = rename_outputs.split(',')
        with graph.as_default():
            for o, n in zip(onames, nnames):
                _out = tf.identity(graph.get_tensor_by_name(o + ':0'), name=n)
            onames = nnames
    input_graph_def = graph.as_graph_def()
    # fix batch norm nodes
    for node in input_graph_def.node:
        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr:
                del node.attr['use_locking']
    with tf.Session(graph=graph) as sess:
        saver.restore(sess, input_checkpoint)
        # In production, graph weights no longer need to be updated;
        # graph_util provides a utility to convert all variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def,
            onames  # unrelated nodes will be discarded
        )
        # Serialize and write to file
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Prune and freeze weights from checkpoints into production models')
    parser.add_argument("--checkpoint_path",
                        default='./regressionDir/',
                        type=str, help="Path to checkpoint files")
    parser.add_argument("--output_nodes",
                        default='Model/dense2/bias_add',
                        type=str, help="Names of output nodes, comma separated")
    parser.add_argument("--output_graph",
                        default='reg.pb',
                        type=str, help="Output graph filename")
    parser.add_argument("--rename_outputs",
                        default='out_vol',
                        type=str, help="Rename output nodes for better "
                                       "readability in production graph, to be specified in "
                                       "the same order as output_nodes")
    args = parser.parse_args()
    freeze_graph(args.checkpoint_path, args.output_nodes, args.output_graph, args.rename_outputs)
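For instance, the freeze script might be invoked like this (the filename freeze.py and the checkpoint directory ./ckpt/ are assumptions; the node names are the defaults from the argparse setup above):

python freeze.py --checkpoint_path ./ckpt/ \
    --output_nodes Model/dense2/bias_add \
    --output_graph reg.pb \
    --rename_outputs out_vol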
Test inference code

import tensorflow as tf
import numpy as np
from PIL import Image
import time
gf = tf.GraphDef()
gf.ParseFromString(open('reg.pb', 'rb').read())
print([n.name + '=>' + n.op for n in gf.node])
output_graph_path = './reg.pb'
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    output_graph_def = tf.GraphDef()
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")
    input_img = sess.graph.get_tensor_by_name("shuffle_batch:0")
    print(input_img)
    out_vol = sess.graph.get_tensor_by_name("out_vol:0")
    out_voln = sess.graph.get_tensor_by_name("shuffle_batch:0")
    a = np.random.random([48, 480, 224, 3])
    for x in range(100):
        ntime1 = time.time()
        vol = sess.run(out_vol, {input_img: a})
        ntime2 = time.time()
        print(ntime2 - ntime1)
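To avoid the fixed feed size imposed by the shuffle_batch node, one option is to remap that node to a placeholder when importing the frozen graph. This is a minimal sketch, assuming the frozen graph's input and output tensors are named shuffle_batch:0 and out_vol:0 as above; the placeholder name input_img is a hypothetical choice:

import tensorflow as tf
import numpy as np

graph_def = tf.GraphDef()
with open('./reg.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

# A placeholder with a flexible batch dimension stands in for the queue output.
input_img = tf.placeholder(tf.float32, [None, 480, 224, 3], name='input_img')
out_vol, = tf.import_graph_def(graph_def,
                               input_map={'shuffle_batch:0': input_img},
                               return_elements=['out_vol:0'],
                               name='')

with tf.Session() as sess:
    # Any batch size now works, e.g. a single image:
    vol = sess.run(out_vol, {input_img: np.random.random([1, 480, 224, 3])})
    print(vol)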