Tensorflow保存/加载冻结的tf.graph并在加载的图上运行分类

时间:2018-10-10 13:53:48

标签: python tensorflow

正如标题所述,我想把 tf.graph 保存为 Frozen_graph.pb 文件。这样做是为了节省空间,之后我打算在 Jetson TX2 上运行它。我写了一个简短的 MNIST 示例来说明我的问题。我的环境是 Python 3.5 和 TF 1.7。

问题1:据我理解,freeze_graph 方法读取一个检查点文件,并把所有变量都转换为常量,只有我通过第二个参数指定的那些除外。为了得到正确的张量名称,我传入了 logits.name,但却报错,提示图中找不到该名称的张量。

问题2:在成功导出冻结图之后,我该如何重新加载它,并用它进行分类?

我的代码附在下面,应该可以作为单个 .py 文件直接运行。先谢谢大家!

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import os
import time

import tensorflow as tf
import os
import argparse


#METHODS I WANT TO TEST
#TAKE THE CHECKPOINT FILE AND DROP ALL NODES THAT ARE NOT NEEDED FOR THE OUTPUTS
def freeze_graph(checkpoint_directory,output_node_names):
    """Freeze the latest checkpoint of a directory into ``frozen_model.pb``.

    The meta graph of the newest checkpoint is imported into a fresh graph,
    the weights are restored, and the variables are replaced by constants
    holding their restored values. The result is written next to the
    checkpoint files.

    Args:
        checkpoint_directory: Directory containing the ``checkpoint`` state
            file written by ``tf.train.Saver``.
        output_node_names: Comma-separated names of the output nodes. Tensor
            names such as ``"logits:0"`` are accepted too: the ``:<index>``
            suffix is stripped, because ``convert_variables_to_constants``
            looks nodes up by op name and cannot find ``"logits:0"``.

    Returns:
        The frozen ``GraphDef``.

    Raises:
        ValueError: If the directory holds no checkpoint or no output node
            name was given.
    """
    print(checkpoint_directory)
    checkpoint = tf.train.get_checkpoint_state(checkpoint_directory)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise ValueError("No checkpoint found in %r" % (checkpoint_directory,))
    input_checkpoint = checkpoint.model_checkpoint_path
    output_graph = os.path.join(os.path.dirname(input_checkpoint), "frozen_model.pb")

    # Accept both op names ("logits") and tensor names ("logits:0").
    node_names = [name.strip().split(":")[0]
                  for name in output_node_names.split(",") if name.strip()]
    if not node_names:
        raise ValueError("output_node_names must name at least one node")

    # Entering the session also makes its (new, empty) graph the default one,
    # so the meta graph below is imported into it and not into the caller's graph.
    with tf.Session(graph = tf.Graph()) as sess:
        #import the metagraph in default graph
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta',clear_devices=True)

        #restore the weights
        saver.restore(sess,input_checkpoint)

        #wrap variables to constants
        graph_def = sess.graph.as_graph_def()
        for node in graph_def.node:
            print(node.name)
        output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, node_names)

        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." %len(output_graph_def.node))

    return output_graph_def



#HERE IS THE METHOD THAT ALLOWS ME TO LOAD MY FROZEN GRAPH AS GRAPH
def load_graph(frozen_graph_filename):
    """Read a serialized ``GraphDef`` from disk and import it into a new graph.

    All imported nodes are placed under the ``prefix`` name scope, so a
    tensor stored as ``logits:0`` is reachable as ``prefix/logits:0``.

    Args:
        frozen_graph_filename: Path of the ``.pb`` file to read.

    Returns:
        A new ``tf.Graph`` containing the imported nodes.
    """
    frozen_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_graph_filename, "rb") as pb_file:
        frozen_def.ParseFromString(pb_file.read())

    loaded_graph = tf.Graph()
    with loaded_graph.as_default():
        tf.import_graph_def(frozen_def, name="prefix")
    return loaded_graph


#get the data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)

#NETWORK PARAMETERS
learning_rate = 0.01

# Fed into the keep_prob placeholder during training, i.e. this is the
# probability of KEEPING an activation, not of dropping it.
dropout = 0.75
display_step = 1
filter_height = 5
filter_width = 5
depth_in = 1
depth_out1 = 64
depth_out2 = 128

#PARAMETERS OF THE DATASET
input_height = 28
input_width = 28
n_classes = 10

#TRAINING PARAMETERS
epochs = 1
batch_size = 256
num_batches = mnist.train.num_examples // batch_size

# Every placeholder gets an explicit name so that it can still be addressed
# by name ("input:0", "keep_prob:0") after the graph has been frozen and
# re-imported; unnamed placeholders end up as "Placeholder", "Placeholder_1", ...
x = tf.placeholder(tf.float32,[None,input_height*input_width],name = "input")
y = tf.placeholder(tf.float32,[None,n_classes],name = "label")
keep_prob = tf.placeholder(tf.float32,name = "keep_prob")

# The network applies two stride-2 max-pools with SAME padding, so each
# spatial dimension is divided by 4 before the fully connected layer.
flattened_size = (input_height // 4) * (input_width // 4) * depth_out2

weights = {'wc1': tf.Variable(tf.random_normal([filter_height,filter_width,depth_in,depth_out1])),
           'wc2': tf.Variable(tf.random_normal([filter_height, filter_width, depth_out1, depth_out2])),
           'wd1': tf.Variable(tf.random_normal([flattened_size,1024])),
           'out': tf.Variable(tf.random_normal([1024,n_classes]))}

biases = {'bc1': tf.Variable(tf.random_normal([depth_out1])),
          'bc2': tf.Variable(tf.random_normal([depth_out2])),
          'bd1': tf.Variable(tf.random_normal([1024])),
          'out': tf.Variable(tf.random_normal([n_classes]))}


#DEFINE YOUR NEURAL NETWORKS LAYER OPERATIONS
def ops_conv2d(x,W,b,strides = 1, add_bias = True, activation = tf.nn.relu, use_activation = True):
    """Apply a SAME-padded 2-D convolution, optionally with bias and activation.

    Args:
        x: Input tensor handed to ``tf.nn.conv2d``.
        W: Convolution filter tensor.
        b: Bias tensor; only used when ``add_bias`` is true.
        strides: Stride used for both spatial dimensions.
        add_bias: Add ``b`` to the convolution result. The flag used to be
            ignored (the bias was always added); the default keeps that
            behaviour.
        activation: Callable applied to the result when ``use_activation``
            is true.
        use_activation: Whether to apply ``activation``.

    Returns:
        The resulting tensor.
    """
    x = tf.nn.conv2d(x,W,strides = [1,strides,strides,1],padding = 'SAME')
    if add_bias:
        x = tf.nn.bias_add(x,b)
    if use_activation:
        return activation(x)
    return x

def ops_maxpool2d(x,stride=2):
    """Max-pool ``x`` with SAME padding, using a square window equal to the stride."""
    pool_shape = [1, stride, stride, 1]
    return tf.nn.max_pool(x, ksize=pool_shape, strides=pool_shape, padding='SAME')

def ops_dropout(input_fully_connected,dropout):
    """Apply ``tf.nn.dropout`` to the given tensor.

    ``dropout`` is forwarded unchanged as the second positional argument of
    ``tf.nn.dropout``; the callers in this script pass the ``keep_prob``
    placeholder there.
    """
    dropped = tf.nn.dropout(input_fully_connected, dropout)
    return dropped

def ops_fullyconnected(input, activation = tf.nn.relu, use_activation = True):
    """Flatten ``input`` and apply the dense layer ``wd1`` / ``bd1``.

    The layer parameters are read from the module-level ``weights`` and
    ``biases`` dictionaries, not from arguments.

    Args:
        input: Tensor that is reshaped to ``[-1, <rows of wd1>]``.
        activation: Callable applied to the result when ``use_activation``
            is true.
        use_activation: Whether to apply ``activation``.

    Returns:
        The (optionally activated) output of the dense layer.
    """
    flat_size = weights['wd1'].get_shape().as_list()[0]
    flattened = tf.reshape(input, [-1, flat_size])
    pre_activation = tf.add(tf.matmul(flattened, weights['wd1']), biases['bd1'])
    if not use_activation:
        return pre_activation
    return activation(pre_activation)

#DEFINE NETWORK ARCHTEKTURE (FORWARDPASS)

def build_network(x,weights,biases,dropout):
    """Build the forward pass: two conv/max-pool stages, a dense layer, logits.

    Note that the dense layer is built by ``ops_fullyconnected``, which reads
    the module-level ``weights`` / ``biases`` instead of these arguments.

    Args:
        x: Input tensor; it is reshaped to ``(-1, 28, 28, 1)``.
        weights: Dict providing ``'wc1'``, ``'wc2'`` and ``'out'``.
        biases: Dict providing ``'bc1'``, ``'bc2'`` and ``'out'``.
        dropout: Value forwarded to ``ops_dropout``.

    Returns:
        The output tensor; its op is named ``logits``.
    """
    images = tf.reshape(x, shape=(-1, 28, 28, 1))

    # First convolution stage followed by 2x2 pooling.
    stage_1 = ops_maxpool2d(
        ops_conv2d(images, weights['wc1'], biases['bc1'],
                   activation=tf.nn.relu, use_activation=True),
        2)

    # Second convolution stage followed by 2x2 pooling.
    stage_2 = ops_maxpool2d(
        ops_conv2d(stage_1, weights['wc2'], biases['bc2'],
                   activation=tf.nn.relu, use_activation=True),
        2)

    hidden = ops_dropout(
        ops_fullyconnected(stage_2, activation=tf.nn.relu, use_activation=True),
        dropout)

    return tf.add(tf.matmul(hidden, weights['out']), biases['out'], name="logits")

#DEFINE TENSORFLOW BACKPROPAGATION OBJECTS (BACKWARDPASS)
# NOTE: the statements below add ops to the default graph in this exact
# order; auto-generated node names depend on that order.

# Forward pass; keep_prob is passed as the network's "dropout" argument.
logits = build_network(x,weights,biases,keep_prob)
# Mean softmax cross-entropy between the logits and the labels fed into y.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = logits,labels = y))

#CHOOSE AN OPTIMIZER
# "optimizer" is the training op returned by minimize(), not the optimizer object.
optimizer = tf.train.AdadeltaOptimizer(learning_rate=learning_rate).minimize(loss=loss)
# Despite its name this is a boolean tensor: True where the arg-max of the
# logits equals the arg-max of y along axis 1.
predicted_labels = tf.equal(tf.argmax(logits,1),tf.argmax(y,1))

#EVALUATION PARAMETERS
# Accuracy = mean of the boolean matches cast to float.
acc = tf.reduce_mean(tf.cast(predicted_labels,tf.float32))

#OP THAT INITIALIZES ALL TF VARIABLES (it only takes effect when run in a session)
init = tf.global_variables_initializer()

# Saver used for the checkpoints; at most 10 checkpoints are kept.
saver = tf.train.Saver(max_to_keep=10)



#NOW START THE SESSION AND EXECUTE THE GRAPH
# Build the checkpoint location once so that saving, freezing and loading all
# agree on it. (os.curdir + "checkpoints" yields ".checkpoints", which did not
# match the "./checkpoints" path used for loading.)
checkpoint_dir = os.path.join(os.curdir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, "MNIST_TEST.ckpt")

with tf.Session() as sess:
    sess.run(init)

    for i in range(epochs):
        for j in range(num_batches):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(optimizer, feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})

            # Report on the current epoch index (the original tested the
            # constant total number of epochs).
            if (i + 1) % display_step == 0:
                losses,accs = sess.run([loss,acc],feed_dict={x:batch_x,y:batch_y,keep_prob:1.})
                print("EPOCH:",'%04d' % (i+1),
                      "loss =", "{:.9f}".format(losses),
                      "acc =", "{:.5f}".format(accs))
        # Save AFTER the epoch's updates, so the checkpoint (and the graph
        # frozen from it) contains trained weights instead of the initial ones.
        save_path = saver.save(sess, checkpoint_prefix)
    print("TRAINING COMPLETED")
    #START PREDICTIONS
    predicted_label = sess.run(logits,feed_dict={x:mnist.test.images[:256],keep_prob:1.})
    test_classes = np.argmax(predicted_label,1)
    print("TEST ACCURACY:",sess.run(acc,feed_dict={x:mnist.test.images[:256], y:mnist.test.labels[:256],keep_prob:1.}))
    f,a = plt.subplots(1,10,figsize = (10,2))

    for i in range(10):
        a[i].imshow(np.reshape(mnist.test.images[i],(28,28)))
        print( test_classes[i])

    print("TOTAL EXAMPLE FINNISHED")

# convert_variables_to_constants looks nodes up by op name ("logits"),
# not by tensor name ("logits:0"), hence logits.op.name.
freeze_graph(checkpoint_dir, logits.op.name)

graph = load_graph(os.path.join(checkpoint_dir, "frozen_model.pb"))
# The frozen graph is a different tf.Graph: tensors of the training graph
# (x, keep_prob, logits) and its init op cannot be used with it. Look the
# imported tensors up by name instead; load_graph puts them under "prefix/".
frozen_input = graph.get_tensor_by_name("prefix/" + x.name)
frozen_keep_prob = graph.get_tensor_by_name("prefix/" + keep_prob.name)
frozen_logits = graph.get_tensor_by_name("prefix/" + logits.name)
# The graph must be passed by keyword: the first positional argument of
# tf.Session is the execution target. No initializer is run because a frozen
# graph contains constants instead of variables.
with tf.Session(graph=graph) as sess:
    predicted_label = sess.run(frozen_logits, feed_dict={frozen_input: mnist.test.images[:256], frozen_keep_prob: 1.})
    print(predicted_label)

1 个答案:

答案 0 :(得分:1)

如果有人遇到同样的问题,下面说明我是如何解决的。保存和加载数据:

首先请注意,我现在使用的是另一套流程。首先,我用 Saver 把会话保存为检查点(ckpt 文件)。之后,我导出图定义(graph.pb)。然后把该图转换为冻结图(frozen.pb)。要加载冻结图,我使用 load_frozen_graph_from_session 方法。在这个方法里,我还测试了一次网络的前向传播。

在加载的图上运行推理: 首先,我给占位符 x 起了名字(name="input"),对应的张量名就是 "input:0"。因此,当您在新会话中给这个占位符提供数据时,需要这样写:predicted_label = sess.run("output:0", feed_dict={"input:0":mnist.test.images[:256], "keep_prob:0": 1.})

这里取到的输出是 logits,而不是网络内部的预测结果。这是因为运行会话时,计算图只会执行到您要获取的那个张量为止。如果我想获取预测结果,就还需要给 y 的占位符(name="label")提供数据。

这是完整的代码:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.examples.tutorials.mnist import input_data
import os
import time

import tensorflow as tf
import os
import argparse
from tensorflow.python.platform import gfile
from tensorflow.python.framework.graph_util import convert_variables_to_constants



#METHODS I WANT TO TEST
def freeze_graph_from_Session(sess,saver):
    """Freeze the graph of a live session into ``./tmp/frozen.pb``.

    ``save_graph`` first writes a checkpoint and ``./tmp/graph.pb``. That
    graph definition is read back, its variables are turned into constants
    using the values held by ``sess``, and the result is written to disk.

    Args:
        sess: Session that holds the variable values.
        saver: Saver handed on to ``save_graph``.
    """
    tmp_dir = "./tmp/"
    save_graph(sess, saver)

    with gfile.FastGFile(tmp_dir + "graph.pb", 'rb') as graph_file:
        unfrozen_def = tf.GraphDef()
        unfrozen_def.ParseFromString(graph_file.read())

    # Signature: convert_variables_to_constants(sess, input_graph_def,
    # output_node_names, variable_names_whitelist=None)
    output_nodes = ["output"]
    frozen_def = convert_variables_to_constants(sess, unfrozen_def, output_nodes)

    with tf.gfile.GFile(tmp_dir + "frozen.pb", "wb") as frozen_file:
        frozen_file.write(frozen_def.SerializeToString())

def save_graph(sess, saver):
    """Write a checkpoint and the serialized graph definition into ``./tmp/``.

    Args:
        sess: Session whose variables and graph definition are saved.
        saver: ``tf.train.Saver`` used to write the checkpoint
            (prefix ``./tmp/model``, global step 1, including the meta graph).
    """
    # open() below does not create missing directories (and the Saver may not
    # either), so make sure the target directory exists on a first run.
    os.makedirs("./tmp/", exist_ok=True)
    saver.save(sess, "./tmp/model", write_meta_graph=True, global_step=1)

    with open("./tmp/" + "graph.pb", 'wb') as f:
        f.write(sess.graph_def.SerializeToString())

def load_frozen_graph_from_session():
    """Load ``./tmp/frozen.pb`` and run a forward pass on it.

    The frozen graph is imported under the ``bla`` name scope with its
    ``input:0`` tensor remapped to a new placeholder. The first 256 MNIST
    test images (module-level ``mnist``) are pushed through the network and
    the predicted classes are printed.

    The previous version additionally re-imported the checkpoint's meta graph
    into the same graph and fetched ``output:0`` / fed ``input:0``, i.e. it
    evaluated the restored training graph and never the frozen one. Now only
    the frozen graph is used, so no checkpoint is required for inference.
    """
    filename = "./tmp/" + "frozen.pb"
    print("LOADING GRAPH")
    with tf.gfile.GFile(filename, "rb") as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    print("OPEN GRAPH")
    with tf.Graph().as_default() as graph:
        print("DEFINE INPUT")
        new_input = tf.placeholder(tf.float32, [None, 28 * 28], name="new_input")
        print("DEFINE INPUT MAP")
        # return_elements takes the names as stored in graph_def (without the
        # "bla/" prefix) and hands back the imported tensors directly.
        keep_prob_tensor, output_tensor = tf.import_graph_def(
            graph_def,
            # usually, during training you use queues, but at inference time use placeholders
            input_map={"input:0": new_input},
            return_elements=["keep_prob:0", "output:0"],
            # if input_map is not None, needs a name
            name="bla"
        )

    # A frozen graph holds constants instead of variables, so nothing has to
    # be restored or initialized before running it.
    with tf.Session(graph=graph) as sess:
        print("TRY FORWARD RUN THROUGH LOADED GRAPH")

        predicted_label = sess.run(output_tensor, feed_dict={new_input: mnist.test.images[:256], keep_prob_tensor: 1.})
        print("output", predicted_label)
        f, a = plt.subplots(1, 10, figsize=(10, 2))
        test_classes = np.argmax(predicted_label, 1)
        for i in range(10):
            a[i].imshow(np.reshape(mnist.test.images[i], (28, 28)))
            print(test_classes[i])
        print ("output:", test_classes)


#TAKE THE CHECKPOINT FILE AND DROP ALL NODES THAT ARE NOT NEEDED FOR THE OUTPUTS
def freeze_graph(checkpoint_directory,output_node_names):
    """Freeze the latest checkpoint of a directory into ``frozen_model.pb``.

    The meta graph of the newest checkpoint is imported into a fresh graph,
    the weights are restored, and the variables are replaced by constants
    holding their restored values. The result is written next to the
    checkpoint files.

    Args:
        checkpoint_directory: Directory containing the ``checkpoint`` state
            file written by ``tf.train.Saver``.
        output_node_names: Comma-separated names of the output nodes. Tensor
            names such as ``"output:0"`` are accepted too: the ``:<index>``
            suffix is stripped, because ``convert_variables_to_constants``
            looks nodes up by op name.

    Returns:
        The frozen ``GraphDef``.

    Raises:
        ValueError: If the directory holds no checkpoint or no output node
            name was given.
    """
    print(checkpoint_directory)
    checkpoint = tf.train.get_checkpoint_state(checkpoint_directory)
    if checkpoint is None or not checkpoint.model_checkpoint_path:
        raise ValueError("No checkpoint found in %r" % (checkpoint_directory,))
    input_checkpoint = checkpoint.model_checkpoint_path
    output_graph = os.path.join(os.path.dirname(input_checkpoint), "frozen_model.pb")

    # Accept both op names ("output") and tensor names ("output:0").
    node_names = [name.strip().split(":")[0]
                  for name in output_node_names.split(",") if name.strip()]
    if not node_names:
        raise ValueError("output_node_names must name at least one node")

    # Entering the session also makes its (new, empty) graph the default one,
    # so the meta graph below is imported into it and not into the caller's graph.
    with tf.Session(graph = tf.Graph()) as sess:
        #import the metagraph in default graph
        saver = tf.train.import_meta_graph(input_checkpoint + '.meta',clear_devices=True)

        #restore the weights
        saver.restore(sess,input_checkpoint)

        #wrap variables to constants
        graph_def = sess.graph.as_graph_def()
        for node in graph_def.node:
            print(node.name)
        output_graph_def = tf.graph_util.convert_variables_to_constants(sess, graph_def, node_names)

        with tf.gfile.GFile(output_graph, "wb") as f:
            f.write(output_graph_def.SerializeToString())
        print("%d ops in the final graph." %len(output_graph_def.node))

    return output_graph_def

#HERE IS THE METHOD THAT ALLOWS ME TO LOAD MY FROZEN GRAPH AS GRAPH
def load_graph(frozen_graph_filename):
    """Read a serialized ``GraphDef`` from disk and import it into a new graph.

    All imported nodes are placed under the ``prefix`` name scope, so a
    tensor stored as ``output:0`` is reachable as ``prefix/output:0``.

    Args:
        frozen_graph_filename: Path of the ``.pb`` file to read.

    Returns:
        A new ``tf.Graph`` containing the imported nodes.
    """
    frozen_def = tf.GraphDef()
    with tf.gfile.GFile(frozen_graph_filename, "rb") as pb_file:
        frozen_def.ParseFromString(pb_file.read())

    loaded_graph = tf.Graph()
    with loaded_graph.as_default():
        tf.import_graph_def(frozen_def, name="prefix")
    return loaded_graph


#get the data
mnist = input_data.read_data_sets("MNIST_data/",one_hot=True)
print(mnist.test.labels[:256])

# Inference-only shortcut: if a frozen graph from an earlier run exists, run
# it and stop. The call used to be unconditional, which made the training code
# below unreachable and crashed on a first run, when ./tmp/frozen.pb does not
# exist yet. Delete ./tmp/frozen.pb to force re-training.
if os.path.exists("./tmp/" + "frozen.pb"):
    print("load_freeze_graph_from_session: STARTED")
    load_frozen_graph_from_session()
    print("load_freeze_graph_from_session: ENDED")
    # Same effect as exit(), without relying on the helper the site module
    # installs for interactive use.
    raise SystemExit

#NETWORK PARAMETERS
learning_rate = 0.01

# Fed into the keep_prob placeholder during training, i.e. this is the
# probability of KEEPING an activation, not of dropping it.
dropout = 0.75
display_step = 1
filter_height = 5
filter_width = 5
depth_in = 1
depth_out1 = 64
depth_out2 = 128

#PARAMETERS OF THE DATASET
input_height = 28
input_width = 28
n_classes = 10

#TRAINING PARAMETERS
epochs = 1
batch_size = 256
num_batches = mnist.train.num_examples // batch_size

# The placeholder names are what a frozen graph is addressed by later on
# ("input:0", "keep_prob:0").
x = tf.placeholder(tf.float32,[None,input_height*input_width],name="input")

y = tf.placeholder(tf.float32,[None,n_classes],name = "label")
keep_prob = tf.placeholder(tf.float32,name = "keep_prob")

# The network applies two stride-2 max-pools with SAME padding, so each
# spatial dimension is divided by 4 before the fully connected layer.
flattened_size = (input_height // 4) * (input_width // 4) * depth_out2

weights = {'wc1': tf.Variable(tf.random_normal([filter_height,filter_width,depth_in,depth_out1])),
           'wc2': tf.Variable(tf.random_normal([filter_height, filter_width, depth_out1, depth_out2])),
           'wd1': tf.Variable(tf.random_normal([flattened_size,1024])),
           'out': tf.Variable(tf.random_normal([1024,n_classes]))}

biases = {'bc1': tf.Variable(tf.random_normal([depth_out1])),
          'bc2': tf.Variable(tf.random_normal([depth_out2])),
          'bd1': tf.Variable(tf.random_normal([1024])),
          'out': tf.Variable(tf.random_normal([n_classes]))}


#DEFINE YOUR NEURAL NETWORKS LAYER OPERATIONS
def ops_conv2d(x,W,b,strides = 1, add_bias = True, activation = tf.nn.relu, use_activation = True):
    """Apply a SAME-padded 2-D convolution, optionally with bias and activation.

    Args:
        x: Input tensor handed to ``tf.nn.conv2d``.
        W: Convolution filter tensor.
        b: Bias tensor; only used when ``add_bias`` is true.
        strides: Stride used for both spatial dimensions.
        add_bias: Add ``b`` to the convolution result. The flag used to be
            ignored (the bias was always added); the default keeps that
            behaviour.
        activation: Callable applied to the result when ``use_activation``
            is true.
        use_activation: Whether to apply ``activation``.

    Returns:
        The resulting tensor.
    """
    x = tf.nn.conv2d(x,W,strides = [1,strides,strides,1],padding = 'SAME')
    if add_bias:
        x = tf.nn.bias_add(x,b)
    if use_activation:
        return activation(x)
    return x

def ops_maxpool2d(x,stride=2):
    """Max-pool ``x`` with SAME padding, using a square window equal to the stride."""
    pool_shape = [1, stride, stride, 1]
    return tf.nn.max_pool(x, ksize=pool_shape, strides=pool_shape, padding='SAME')

def ops_dropout(input_fully_connected,dropout):
    """Apply ``tf.nn.dropout`` to the given tensor.

    ``dropout`` is forwarded unchanged as the second positional argument of
    ``tf.nn.dropout``; the callers in this script pass the ``keep_prob``
    placeholder there.
    """
    dropped = tf.nn.dropout(input_fully_connected, dropout)
    return dropped

def ops_fullyconnected(input, activation = tf.nn.relu, use_activation = True):
    """Flatten ``input`` and apply the dense layer ``wd1`` / ``bd1``.

    The layer parameters are read from the module-level ``weights`` and
    ``biases`` dictionaries, not from arguments.

    Args:
        input: Tensor that is reshaped to ``[-1, <rows of wd1>]``.
        activation: Callable applied to the result when ``use_activation``
            is true.
        use_activation: Whether to apply ``activation``.

    Returns:
        The (optionally activated) output of the dense layer.
    """
    flat_size = weights['wd1'].get_shape().as_list()[0]
    flattened = tf.reshape(input, [-1, flat_size])
    pre_activation = tf.add(tf.matmul(flattened, weights['wd1']), biases['bd1'])
    if not use_activation:
        return pre_activation
    return activation(pre_activation)

#DEFINE NETWORK ARCHTEKTURE (FORWARDPASS)

def build_network(x,weights,biases,dropout):
    """Build the forward pass: two conv/max-pool stages, a dense layer, logits.

    Note that the dense layer is built by ``ops_fullyconnected``, which reads
    the module-level ``weights`` / ``biases`` instead of these arguments.

    Args:
        x: Input tensor; it is reshaped to ``(-1, 28, 28, 1)``.
        weights: Dict providing ``'wc1'``, ``'wc2'`` and ``'out'``.
        biases: Dict providing ``'bc1'``, ``'bc2'`` and ``'out'``.
        dropout: Value forwarded to ``ops_dropout``.

    Returns:
        The logits tensor; its op is named ``output`` so that it can be
        fetched as ``output:0``.
    """
    images = tf.reshape(x, shape=(-1, 28, 28, 1))

    # First convolution stage followed by 2x2 pooling.
    stage_1 = ops_maxpool2d(
        ops_conv2d(images, weights['wc1'], biases['bc1'],
                   activation=tf.nn.relu, use_activation=True),
        2)

    # Second convolution stage followed by 2x2 pooling.
    stage_2 = ops_maxpool2d(
        ops_conv2d(stage_1, weights['wc2'], biases['bc2'],
                   activation=tf.nn.relu, use_activation=True),
        2)

    hidden = ops_dropout(
        ops_fullyconnected(stage_2, activation=tf.nn.relu, use_activation=True),
        dropout)

    return tf.add(tf.matmul(hidden, weights['out']), biases['out'], name="output")

#DEFINE TENSORFLOW BACKPROPAGATION OBJECTS (BACKWARDPASS)
# NOTE: the statements below add ops to the default graph in this exact
# order; auto-generated node names depend on that order.

# Forward pass; keep_prob is passed as the network's "dropout" argument.
logits = build_network(x,weights,biases,keep_prob)
#freeze_graph(os.curdir + "checkpoints" + os.sep, logits.name)
# Mean softmax cross-entropy between the logits and the labels fed into y.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits = logits,labels = y))

#CHOOSE AN OPTIMIZER
# "optimizer" is the training op returned by minimize(), not the optimizer object.
optimizer = tf.train.AdadeltaOptimizer(learning_rate=learning_rate).minimize(loss=loss)
# Despite its name this is a boolean tensor: True where the arg-max of the
# logits equals the arg-max of y along axis 1.
predicted_labels = tf.equal(tf.argmax(logits,1),tf.argmax(y,1))

#EVALUATION PARAMETERS
# Accuracy = mean of the boolean matches cast to float.
acc = tf.reduce_mean(tf.cast(predicted_labels,tf.float32))

#OP THAT INITIALIZES ALL TF VARIABLES (it only takes effect when run in a session)
init = tf.global_variables_initializer()

# Saver used for the checkpoints; at most 10 checkpoints are kept.
saver = tf.train.Saver(max_to_keep=10)



#NOW START THE SESSION AND EXECUTE THE GRAPH
# os.curdir + "checkpoints/..." yields ".checkpoints/..." (a hidden directory);
# build the intended "./checkpoints" path explicitly and make sure it exists.
checkpoint_dir = os.path.join(os.curdir, "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_prefix = os.path.join(checkpoint_dir, "MNIST_TEST.ckpt")

with tf.Session() as sess:
    sess.run(init)

    for i in range(epochs):
        for j in range(num_batches):
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            sess.run(optimizer, feed_dict={x:batch_x,y:batch_y,keep_prob:dropout})

            # Report on the current epoch index (the original tested the
            # constant total number of epochs).
            if (i + 1) % display_step == 0:
                losses,accs = sess.run([loss,acc],feed_dict={x:batch_x,y:batch_y,keep_prob:1.})
                print("EPOCH:",'%04d' % (i+1),
                      "loss =", "{:.9f}".format(losses),
                      "acc =", "{:.5f}".format(accs))
        # Save AFTER the epoch's updates so the checkpoint holds trained weights.
        save_path = saver.save(sess, checkpoint_prefix)
    print("TRAINING COMPLETED")
    #START PREDICTIONS
    predicted_label = sess.run(logits,feed_dict={x:mnist.test.images[:256],keep_prob:1.})
    test_classes = np.argmax(predicted_label,1)
    print("TEST ACCURACY:",sess.run(acc,feed_dict={x:mnist.test.images[:256], y:mnist.test.labels[:256],keep_prob:1.}))
    f,a = plt.subplots(1,10,figsize = (10,2))

    for i in range(10):
        a[i].imshow(np.reshape(mnist.test.images[i],(28,28)))
        print( test_classes[i])

    print("TOTAL EXAMPLE FINNISHED")

    # Freezing needs the live session because the variable values are read from it.
    print("freeze_graph_from_session: STARTED")
    freeze_graph_from_Session(sess,saver)
    print("freeze_graph_from_session: ENDED")

print("load_freeze_graph_from_session: STARTED")
load_frozen_graph_from_session()
print("load_freeze_graph_from_session: ENDED")

#with tf.Session() as sess:
#
#    sess.run(init)
#    graph = load_graph(os.curdir + os.sep + "checkpoints" + os.sep + "frozen_model.pb")
#    predicted_label = sess.run(logits, feed_dict={x: mnist.test.images[:256], keep_prob: 1.})
#    print(predicted_label)

感谢我。 :)