How do I freeze a TensorFlow ckpt that was trained with a QueueRunner?

Time: 2018-11-28 11:59:12

Tags: tensorflow

I trained my network with tf.train.shuffle_batch and did not use any placeholder. When I freeze the ckpt into a pb file, the graph has no real input tensor: I can feed data into the shuffle_batch tensor, but the feed must have exactly the same shape (including the batch size) as the shuffle_batch output. How can I solve this? I know I could rebuild the network with a placeholder, restore the parameters, and then freeze, but isn't that rather clumsy?
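
For reference, a minimal sketch of the "rebuild with a placeholder, restore, then freeze" route mentioned above. It assumes the `inference` function and the `./ckpt/` directory from the training script below; the module name `train`, the placeholder name `input_img`, the output name `out_vol`, and the file name `reg_placeholder.pb` are only illustrative:

import tensorflow as tf
from tensorflow.python.framework import graph_util

# `inference` is the network-building function defined in the training script
# below; the module name here is hypothetical.
from train import inference

# Rebuild the same network on a placeholder whose batch dimension is dynamic.
input_img = tf.placeholder(tf.float32, [None, 480, 224, 3], name='input_img')
_, _, volume = inference(input_img, is_train=False, reuse=None)
out_vol = tf.identity(volume, name='out_vol')

saver = tf.train.Saver()
with tf.Session() as sess:
    # Restore the weights that were trained through the queue-runner pipeline.
    saver.restore(sess, tf.train.latest_checkpoint('./ckpt/'))
    frozen = graph_util.convert_variables_to_constants(
        sess, sess.graph.as_graph_def(), ['out_vol'])
    with tf.gfile.GFile('reg_placeholder.pb', 'wb') as f:
        f.write(frozen.SerializeToString())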

Training code

import tensorflow as tf
import tensorlayer as tl
from tensorlayer.layers import *
import os
trainsize=35680
testsize=889

batch_size=32
inputW=224
inputH=480
TRAIN_TFRECORD='./train.tfrecords'
TEST_TFRECORD='./test.tfrecords'
BATCH_CAPACITY=512
MIN_AFTER_DEQU=256
MAX_Cycle=100000
TRAIN_CYCLE=int(trainsize/batch_size)
TEST_CYCLE=int(testsize/batch_size)
learning_rt = 0.001
savepath='./ckpt/'
logpath='./logs/'


def network(inputs,is_train,reuse):
    BITW =8
    BITA=8
    Decay=0.99
    Epsi=1e-5
    with tf.variable_scope('Model',reuse=reuse):
        net=InputLayer(inputs,name='input') #224*480

        net=QuanConv2dWithBN(net,32,(3,3),(1,1),'SAME',tf.nn.relu, decay=Decay, epsilon=Epsi,  is_train=is_train,  bitW=BITW,bitA=BITA,name='Conv1_1')

        net=QuanConv2dWithBN(net,64,(3,3),(2,2),'SAME',tf.nn.relu, decay=Decay, epsilon=Epsi,  is_train=is_train,  bitW=BITW,bitA=BITA,name='Conv1_2') #112*240

        net=QuanConv2dWithBN(net,64,(3,3),(1,1),'SAME',tf.nn.relu,decay=Decay, epsilon=Epsi,  is_train=is_train,  bitW=BITW,bitA=BITA,name='Conv2_1')
        net=QuanConv2dWithBN(net,128,(3,3),(2,2),'SAME',tf.nn.relu,decay=Decay, epsilon=Epsi,  is_train=is_train,  bitW=BITW,bitA=BITA,name='Conv2_2') #56*120

        net=QuanConv2dWithBN(net,128,(3,3),(1,1),'SAME',tf.nn.relu,decay=Decay, epsilon=Epsi,  is_train=is_train,  bitW=BITW,bitA=BITA,name='Conv3_1')

        net=QuanConv2dWithBN(net,64,(1,1),(1,1),'SAME',tf.nn.relu,decay=Decay, epsilon=Epsi,  is_train=is_train,  bitW=BITW,bitA=BITA,name='Conv3_2')

        net=QuanConv2dWithBN(net,128,(3,3),(2,2),'SAME',tf.nn.relu,decay=Decay, epsilon=Epsi,  is_train=is_train,  bitW=BITW,bitA=BITA,name='Conv3_3') #28*60

        net=QuanConv2dWithBN(net,64,(3,3),(1,1),'SAME',tf.nn.relu,decay=Decay, epsilon=Epsi,  is_train=is_train,  bitW=BITW,bitA=BITA,name='Conv4_1')

        net=QuanConv2dWithBN(net,96,(3,3),(2,2),'VALID',tf.nn.relu,decay=Decay, epsilon=Epsi,  is_train=is_train,  bitW=BITW,bitA=BITA,name='Conv4_2') #14*30

        print(net.outputs)
        net=QuanConv2dWithBN(net,128,(3,3),(2,2),'SAME',tf.nn.relu,decay=Decay, epsilon=Epsi,  is_train=is_train,  bitW=BITW,bitA=BITA,name='Conv5_1') #7*30

        net=QuanConv2dWithBN(net,128,(3,3),(1,2),'VALID',tf.nn.relu,decay=Decay, epsilon=Epsi,  is_train=is_train,  bitW=BITW,bitA=BITA,name='Conv5_2') #3*30

        net=QuanConv2d(net,128,(3,3),(1,2),'VALID',tf.nn.leaky_relu,bitW=BITW,bitA=BITA,name='Conv5_3') #1*30
        print(net.outputs)
        net=FlattenLayer(net,name='flat1')
        net=QuanDenseLayer(net,128,act=tf.nn.leaky_relu,bitW=BITW,bitA=BITA,name='dense1')

        net=DropoutLayer(net,0.5,is_fix=True,is_train=is_train,name='drop1')
        net=DenseLayer(net,1,name='dense2')
        outnet=net
        volcume=net.outputs
        print(volcume)
        return outnet,net.outputs,volcume


def inference(inputs,is_train,reuse):
    return network(inputs,is_train,reuse)



def read_and_decode(filename):
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'img': tf.FixedLenFeature([], tf.string),
                                           'num' : tf.FixedLenFeature([], tf.float32),
                                       })

    img = tf.decode_raw(features['img'], tf.uint8)
    img = tf.reshape(img, [480, 240, 3])
    img=tf.random_crop(img,[480,224,3])
    img = tf.image.random_brightness(img, max_delta=0.3)
    img = tf.image.random_contrast(img, lower=0.1, upper=0.5)
    # img = tf.image.random_hue(img, max_delta=0.1)
    # img = tf.image.random_saturation(img, lower=0, upper=2.5)
    img = tf.image.per_image_standardization(img)
    label = tf.reshape( tf.cast(features['num'], tf.float32)*(1./230.)-0.5,[1])
    return img, label

def read_and_decode_test(filename):
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(serialized_example,
                                       features={
                                           'img': tf.FixedLenFeature([], tf.string),
                                           'num' : tf.FixedLenFeature([], tf.float32),
                                       })

    img = tf.decode_raw(features['img'], tf.uint8)
    img = tf.reshape(img, [480, 240, 3])
    img=img[:,8:232,:]
    img = tf.image.per_image_standardization(img)
    label = tf.reshape( tf.cast(features['num'], tf.float32)*(1./230.)-0.5,[1])
    return img, label


def smooth_L1(x):
    return tf.where(tf.less_equal(tf.abs(x), 1.0), tf.multiply(0.5, tf.pow(x, 2.0)), tf.subtract(tf.abs(x), 0.5))

def cal_loss(logits,labels):
    # return tf.clip_by_value(tf.reduce_mean(tf.losses.mean_squared_error(labels,logits)) ,0.000001,10000000.)
    return tf.reduce_mean(tf.where(tf.less_equal(tf.abs(logits-labels), 0.02),0.00001*tf.ones_like(logits-labels),  tf.multiply(1., tf.pow(logits-labels, 2.0))))
    # return tf.clip_by_value(tf.reduce_sum(smooth_L1(labels-logits)),0.0000001,100.)

def cal_acc(logits,labels):
    return tf.reduce_mean( tf.cast( tf.less_equal(tf.abs(labels-logits),tf.ones_like(labels)*.1),tf.float32))

if __name__ == '__main__':

    img_train,num_train  = read_and_decode(TRAIN_TFRECORD)
    img_test,num_test  = read_and_decode_test(TEST_TFRECORD)
    img_train_batch,  num_train_batch = tf.train.shuffle_batch(
        [img_train, num_train], batch_size=batch_size, capacity=BATCH_CAPACITY,
        min_after_dequeue=MIN_AFTER_DEQU)
    img_test_batch,  num_test_batch = tf.train.batch(
        [img_test,num_test], batch_size=batch_size)
    net,_,logits_train=inference(img_train_batch,True,None)
    _,_,logits_test=inference(img_test_batch,False,True)
    loss_train=cal_loss(logits_train,num_train_batch)
    loss_test=cal_loss(logits_test,num_test_batch)
    acc_test=cal_acc(logits_test,num_test_batch)
    acc_train=cal_acc(logits_train,num_train_batch)
    global_step=tf.train.create_global_step()
    #tf.train.get_global_step()
    learning_rate=tf.train.exponential_decay(learning_rt, global_step,
                                           5000, 0.9, staircase=True)
    train = tf.train.MomentumOptimizer(learning_rate,momentum=0.9).minimize(loss_train,global_step=global_step)
    # train = tf.train.AdamOptimizer(learning_rt).minimize(loss_train)


    tf.summary.scalar('loss_train', loss_train)
    tf.summary.scalar('acc_train', acc_train)
    merged = tf.summary.merge_all()
    with tf.Session(config=tf.ConfigProto()) as sess:
        trainwrite = tf.summary.FileWriter(logpath, sess.graph)
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver()
        run_cycle=0
        if os.path.exists(savepath+'313.ckpt.index') :
            print('\nStart Restore')
            saver.restore(sess, savepath+'313.ckpt')
            print('\nEnd Restore')
        print('\nStart Training')
        try:

            while not coord.should_stop():

                while run_cycle < MAX_Cycle:

                    run_cycle+=1
                    # if run_cycle%10==0:
                    #     learning_rt*=0.6
                    # if run_cycle%200==0:
                    #     learning_rt*=2.
                    l_tall=0
                    a_tall=0

                    l_teall=0
                    a_teall=0

                    for train_c in range(TRAIN_CYCLE):


                        _,l_train,a_train=sess.run([train,loss_train,acc_train])
                        l_tall+=l_train
                        a_tall+=a_train

                        if (train_c+1)%100==0:
                            print('train_loss:%f'%(l_tall/100.))
                            print('train_acc:%f'%(a_tall/100.))
                            l_tall = 0
                            a_tall = 0
                        if (train_c+1)%500==0:
                            print('Global Step:',sess.run(global_step))
                            result_merged=sess.run(merged)
                            trainwrite.add_summary(result_merged, run_cycle*TRAIN_CYCLE+train_c)
                    for test_c in range(TEST_CYCLE):
                        l_test,a_test=sess.run([loss_test,acc_test])
                        l_teall+=l_test
                        a_teall+=a_test
                        if (test_c+1)%TEST_CYCLE==0:
                            print('------------------')
                            print('test_loss:%f'%(l_teall/TEST_CYCLE))
                            print('test_acc:%f'%(a_teall/TEST_CYCLE))
                            print('------------------')
                            l_teall = 0
                            a_teall = 0
                    saver.save(sess, savepath+ str(run_cycle) + '.ckpt')




        except tf.errors.OutOfRangeError:
            print('Done training!!!')
        finally:
            # When done, ask the threads to stop.
            coord.request_stop()

        coord.join(threads)
        sess.close()

Freezing code

import os, argparse

import tensorflow as tf
from tensorflow.python.framework import graph_util

dir = os.path.dirname(os.path.realpath(__file__))


def freeze_graph(model_folder, output_nodes='y_hat',
                 output_filename='frozen-graph.pb',
                 rename_outputs=None):
    # Load checkpoint
    checkpoint = tf.train.get_checkpoint_state(model_folder)
    input_checkpoint = checkpoint.model_checkpoint_path

    output_graph = output_filename

    # Devices should be cleared to allow Tensorflow to control placement of
    # graph when loading on different machines
    saver = tf.train.import_meta_graph(input_checkpoint + '.meta',
                                       clear_devices=True)

    graph = tf.get_default_graph()
    onames = output_nodes.split(',')

    # https://stackoverflow.com/a/34399966/4190475
    if rename_outputs is not None:
        nnames = rename_outputs.split(',')
        with graph.as_default():
            for o, n in zip(onames, nnames):
                _out = tf.identity(graph.get_tensor_by_name(o + ':0'), name=n)
            onames = nnames

    input_graph_def = graph.as_graph_def()

    # fix batch norm nodes
    for node in input_graph_def.node:

        if node.op == 'RefSwitch':
            node.op = 'Switch'
            for index in range(len(node.input)):
                if 'moving_' in node.input[index]:
                    node.input[index] = node.input[index] + '/read'
        elif node.op == 'AssignSub':
            node.op = 'Sub'
            if 'use_locking' in node.attr: del node.attr['use_locking']

    with tf.Session(graph=graph) as sess:
        saver.restore(sess, input_checkpoint)

        # In production, graph weights no longer need to be updated
        # graph_util provides utility to change all variables to constants
        output_graph_def = graph_util.convert_variables_to_constants(
            sess, input_graph_def,
            onames  # unrelated nodes will be discarded
        )

        # Serialize and write to file
        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." % len(output_graph_def.node))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Prune and freeze weights from checkpoints into production models')
    parser.add_argument("--checkpoint_path",
                        default='./regressionDir/',
                        type=str, help="Path to checkpoint files")
    parser.add_argument("--output_nodes",
                        default='Model/dense2/bias_add',
                        type=str, help="Names of output node, comma seperated")
    parser.add_argument("--output_graph",
                        default='reg.pb',
                        type=str, help="Output graph filename")
    parser.add_argument("--rename_outputs",
                        default='out_vol',
                        type=str, help="Rename output nodes for better \
                            readability in production graph, to be specified in \
                            the same order as output_nodes")
    args = parser.parse_args()

    freeze_graph(args.checkpoint_path, args.output_nodes, args.output_graph, args.rename_outputs)

Test inference code

import tensorflow as tf
import numpy as np
from PIL import Image
import time

gf = tf.GraphDef()
gf.ParseFromString(open('reg.pb', 'rb').read())
print([n.name + '=>' + n.op for n in gf.node])

output_graph_path = './reg.pb'
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    output_graph_def = tf.GraphDef()
    with open(output_graph_path, "rb") as f:
        output_graph_def.ParseFromString(f.read())
        _ = tf.import_graph_def(output_graph_def, name="")

    input_img = sess.graph.get_tensor_by_name("shuffle_batch:0")
    print(input_img)
    out_vol = sess.graph.get_tensor_by_name("out_vol:0")
    out_voln = sess.graph.get_tensor_by_name("shuffle_batch:0")
    a = np.random.random([48, 480, 224, 3])
    for x in range(100):
        ntime1 = time.time()
        vol = sess.run(out_vol, {input_img: a})
        ntime2 = time.time()
        print(ntime2 - ntime1)
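
If rebuilding the network is to be avoided, one possible workaround (a sketch, not tested against this model) is to remap the queue output to a fresh placeholder at import time via the input_map argument of tf.import_graph_def, so the feed is no longer tied to the shuffle_batch shape. The placeholder name input_img is my own choice:

import tensorflow as tf
import numpy as np

# Load the frozen graph produced by the freezing code above.
graph_def = tf.GraphDef()
with open('./reg.pb', 'rb') as f:
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    # Route the graph's input through a placeholder instead of the shuffle_batch node.
    input_img = tf.placeholder(tf.float32, [None, 480, 224, 3], name='input_img')
    out_vol, = tf.import_graph_def(graph_def,
                                   input_map={'shuffle_batch:0': input_img},
                                   return_elements=['out_vol:0'])
    batch = np.random.random([1, 480, 224, 3]).astype(np.float32)
    print(sess.run(out_vol, {input_img: batch}))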

0 Answers:

There are no answers yet.