我有一个 TensorFlow(Python)的 MNIST 训练脚本,它会生成用于推断的冻结(frozen)*.pb 文件:
import tensorflow as tf
from tensorflow.contrib import layers
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.python.framework.graph_util import convert_variables_to_constants
import utils
# Train a 784-300-300-10 MLP on MNIST, then write:
#   * a training checkpoint ('./normal_train2.ckpt.*' plus '.meta') that still
#     contains train_op, the Adam slot variables and global_step, so training
#     can be restored/resumed later; and
#   * a frozen inference graph 'normal_train2.pb'.
epochs = 250
batch_size = 55000  # Entire training set
# Path prefix passed to saver.save(); the meta graph (with train_op) goes to
# checkpoint_path + '.meta'.
checkpoint_path = './normal_train2.ckpt'
# Import dataset
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# Number of full batches per epoch (integer division).
batches = len(mnist.train.images) // batch_size
# Define Placeholders
image = tf.placeholder(tf.float32, [None, 784], name='image')
label = tf.placeholder(tf.float32, [None, 10], name='label')
# Define the model: two 300-unit hidden layers (ReLU, the fully_connected
# default) and a linear 10-unit output layer.
layer1 = layers.fully_connected(image, 300)
layer2 = layers.fully_connected(layer1, 300)
# activation_fn=None: logits must be unbounded; the default ReLU would clip
# every negative logit to 0 before the softmax cross-entropy.
logits = layers.fully_connected(layer2, 10, activation_fn=None)
# Create global step variable (needed for pruning)
global_step = tf.train.get_or_create_global_step()
reset_global_step_op = tf.assign(global_step, 0)
# Loss function
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=logits, labels=label))
# Running this operation increments global_step. Its explicit name 'train_op'
# lets it be looked up again after tf.train.import_meta_graph(), e.g.
# tf.get_default_graph().get_operation_by_name('train_op').
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step, name='train_op')
# Accuracy ops
correct_prediction = tf.equal(tf.argmax(logits, 1), tf.argmax(label, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name='accuracy')
# Predicted class index under a stable name, so inference code does not have
# to rely on the auto-generated 'ArgMax' name. Created after the two argmax
# ops above so their names ('ArgMax', 'ArgMax_1') are unchanged.
prediction = tf.argmax(logits, 1, name='prediction')
# Created after minimize() so it also covers the Adam slot variables.
saver = tf.train.Saver()
with tf.Session() as sess:
    # Train from freshly initialized weights. To continue from a previous run
    # instead, call saver.restore(sess, checkpoint_path) here.
    sess.run(tf.global_variables_initializer())
    # Train the model before pruning (optional)
    for epoch in range(epochs):
        for _ in range(batches):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_op, feed_dict={image: batch_xs, label: batch_ys})
        # Calculate Test Accuracy every 10 epochs
        if epoch % 10 == 0:
            acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
            print("Normal Train Model step %d test accuracy %g" % (epoch, acc_print))
    acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("Normal Train Model accuracy:", acc_print)
    # Save a restorable training checkpoint. The frozen .pb written below
    # cannot be used to resume training: freezing replaces variables with
    # constants and keeps only what the listed output nodes need, so
    # train_op and the optimizer state are dropped from it.
    saver.save(sess, checkpoint_path)
    # Save to full converted pb file (inference only)
    graph = convert_variables_to_constants(sess, sess.graph_def, ["accuracy", "prediction"])
    with tf.gfile.GFile('normal_train2.pb', mode='wb') as f:
        f.write(graph.SerializeToString())
目前我能做到的是用这个 *.pb 文件进行推断,从而评估一些图像:
# Evaluate a frozen MNIST graph (*.pb) on the MNIST test set.
# Usage: python <this script> path/to/model.pb
import os
import sys

if __name__ == "__main__":
    # Set before TensorFlow is imported/used so the level applies to all of
    # TF's native log output (TF reads this variable when it first logs).
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import utils

# Session options used by main(): allocate GPU memory on demand instead of
# reserving it all up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)


def load_model(path_to_model):
    """Load a frozen GraphDef (*.pb) into a new tf.Graph and return it.

    The graph is imported with name='' so tensors keep their original names
    (e.g. 'image:0', 'label:0', 'accuracy:0').

    Raises:
        ValueError: if path_to_model does not exist.
    """
    if not os.path.exists(path_to_model):
        raise ValueError("'%s' does not exist." % path_to_model)
    model_graph = tf.Graph()
    with model_graph.as_default():
        od_graph_def = tf.GraphDef()
        with tf.gfile.GFile(path_to_model, 'rb') as fid:
            od_graph_def.ParseFromString(fid.read())
        tf.import_graph_def(od_graph_def, name='')
    return model_graph


def main(path_to_model):
    """Load the frozen model at path_to_model and print its MNIST test accuracy.

    For per-image predicted labels, fetch the argmax-over-logits tensor
    instead (named 'ArgMax:0' in graphs produced by the training script);
    note it holds class indices, not raw logits.
    """
    print("path_to_model:", path_to_model)
    model_graph = load_model(path_to_model)
    accuracy = model_graph.get_tensor_by_name('accuracy:0')
    image = model_graph.get_tensor_by_name('image:0')
    label = model_graph.get_tensor_by_name('label:0')
    with tf.Session(graph=model_graph, config=config) as sess:
        acc_print = sess.run(accuracy, feed_dict={image: mnist.test.images, label: mnist.test.labels})
    print("eval_from_pb accuracy:", acc_print)


if __name__ == "__main__":
    if len(sys.argv) != 2:
        sys.exit("usage: %s path/to/model.pb" % sys.argv[0])
    main(sys.argv[1])
另外,通过查看 *.pbtxt 并打印出图中的操作,我知道可以用 `logits = model_graph.get_tensor_by_name('ArgMax:0')` 来取得并打印图像的预测标签。
我想问的是:能否通过 *.pb 文件执行 `train_op`?我并不需要接着之前的进度继续训练,只想知道如何从 *.pb 文件中恢复 `train_op`,以便从头开始训练。
我已经读过 “Re-train a frozen *.pb model in TensorFlow” 和 “Cloning a network with tf.contrib.graph_editor” 这两个问题,但其中都没有可以直接运行的示例;我也照着尝试过,但没有成功。