使用MNIST示例重新训练冻结的pb文件

时间:2018-12-03 10:26:11

标签: python tensorflow mnist

我有一个 TensorFlow(Python)的 MNIST 训练脚本,该脚本可以生成用于推断的冻结 *.pb 文件。

import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework.graph_util import convert_variables_to_constants
import utils

# Training hyper-parameters.
epochs = 250
batch_size = 55000  # Entire training set

# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Number of mini-batches per epoch (integer division; any remainder is dropped).
batches = len(mnist.train.images) // batch_size

# Define Placeholders. The explicit names are what a consumer of the frozen
# graph uses to look these tensors up ('image:0', 'label:0').
image = tf.placeholder(tf.float32, [None, 784], name='image')
label = tf.placeholder(tf.float32, [None, 10], name='label')

# Define the model: two hidden layers of 300 units followed by a 10-way output.
layer1 = layers.fully_connected(image, 300)
layer2 = layers.fully_connected(layer1, 300)
# fully_connected applies ReLU by default; the output layer must stay linear
# because softmax_cross_entropy_with_logits_v2 expects unscaled logits.
logits = layers.fully_connected(layer2, 10, activation_fn=None)

# Create global step variable (needed for pruning)
global_step = tf.train.get_or_create_global_step()
# NOTE(review): this reset op is not run anywhere in the visible script.
reset_global_step_op = tf.assign(global_step, 0)

# Loss function: mean softmax cross-entropy over the batch.
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))

# running this operation increments the global_step
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step, name='train_op')

# Accuracy ops: fraction of samples whose arg-max prediction matches the
# arg-max of the one-hot label. The op is named so it can be fetched as
# 'accuracy:0' from the frozen graph.
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')

# Create a saver for writing training checkpoints.
saver = tf.train.Saver()

with tf.Session() as sess:
    # Initialise every variable; this script always trains from scratch.
    # (tf.initialize_all_variables is deprecated in favour of this call.)
    sess.run(tf.global_variables_initializer())

    # The same test-set feed is used for every accuracy evaluation below.
    test_feed = {image: mnist.test.images, label: mnist.test.labels}

    # Train the model before pruning (optional)
    for epoch in range(epochs):
        for batch in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})

        # Calculate Test Accuracy every 10 epochs
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict=test_feed)
            print("Normal Train Model step %d test accuracy %g" % (epoch, acc_print))

    # Final accuracy after the last epoch.
    acc_print = sess.run(accuracy, feed_dict=test_feed)
    print("Normal Train Model accuracy:", acc_print)

    # Save to full converted pb file: variables are folded into constants and
    # only the sub-graph needed to compute the "accuracy" node is kept.
    graph = convert_variables_to_constants(sess, sess.graph_def, ["accuracy"])
    # GFile replaces the deprecated FastGFile alias.
    with tf.gfile.GFile('normal_train2.pb', mode='wb') as f:
        f.write(graph.SerializeToString())

我目前能做到的是:使用这个 *.pb 文件进行推断,从而对一些图像进行评估。

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import os
import utils
import sys

# Session configuration with GPU memory "allow_growth" enabled.
config = tf.ConfigProto(gpu_options=tf.GPUOptions(allow_growth=True))
# NOTE(review): this session is created at import time but is not referenced
# anywhere else in the visible script.
session = tf.Session(config=config)

# MNIST dataset with one-hot encoded labels, read from ./MNIST_data.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

def load_model(path_to_model):
    """Load a frozen GraphDef (*.pb) file into a new ``tf.Graph``.

    Args:
        path_to_model: Filesystem path of the serialized GraphDef file.

    Returns:
        A new ``tf.Graph`` holding the imported nodes. Nodes are imported with
        an empty name prefix (``name=''``), so tensor names match the ones
        used when the graph was exported.

    Raises:
        ValueError: If ``path_to_model`` does not exist.
    """
    if not os.path.exists(path_to_model):
        # Report the path that was actually requested, not a fixed placeholder.
        raise ValueError("Model file %r does not exist." % (path_to_model,))

    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_to_model, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')
    return model_graph

def main(path_to_model):
    """Evaluate a frozen MNIST model on the MNIST test set and print accuracy.

    Args:
        path_to_model: Path of the frozen *.pb file, passed to ``load_model``.
    """
    print("path_to_model:", path_to_model)
    model_graph = load_model(path_to_model)

    # Resolve the tensors by name, in this order. 'Equal:0' and 'ArgMax:0' are
    # looked up but not used below; the lookup still fails if they are absent.
    accuracy, image, label, equal, logits = (
        model_graph.get_tensor_by_name(tensor_name)
        for tensor_name in ('accuracy:0', 'image:0', 'label:0', 'Equal:0', 'ArgMax:0')
    )

    with tf.Session(graph=model_graph) as sess:
        test_feed = {image: mnist.test.images, label: mnist.test.labels}
        acc_print = sess.run(accuracy, feed_dict=test_feed)
        print("eval_from_pb accuracy:", acc_print)

if __name__ == "__main__":
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
    # Exactly one argument (the frozen model path) is required; fail with a
    # usage message rather than an opaque TypeError from main().
    if len(sys.argv) != 2:
        sys.exit("usage: %s <path_to_frozen_model.pb>" % sys.argv[0])
    main(sys.argv[1])

另外,通过查看 *.pbtxt 并打印出其中的操作,我知道可以使用 logits = model_graph.get_tensor_by_name('ArgMax:0') 来实际打印图像的预测标签。

我想问的是:是否可以通过 *.pb 文件执行 train_op?我不需要在原有基础上继续训练,只想知道如何使用 *.pb 文件恢复 train_op,以便从头开始训练。

我已经读过《Re-train a frozen *.pb model in TensorFlow》和《Cloning a network with tf.contrib.graph_editor》,但它们都没有给出可用的示例;我都尝试过了,但都没有成功。

0 个答案:

没有答案