我找到了以下代码段来显示保存到*.pb
文件的模型:
model_filename ='saved_model.pb'
with tf.Session() as sess:
with gfile.FastGFile(path_to_model_pb, 'rb') as f:
data = compat.as_bytes(f.read())
sm = saved_model_pb2.SavedModel()
sm.ParseFromString(data)
g_in = tf.import_graph_def(sm.meta_graphs[0].graph_def)
LOGDIR='.'
train_writer = tf.summary.FileWriter(LOGDIR)
train_writer.add_graph(sess.graph)
现在我正在努力创建saved_model.pb
。如果我的session.run看起来像这样:
_, cr_loss = sess.run([train_op,cross_entropy_loss],
feed_dict={input_image: images,
correct_label: gt_images,
keep_prob: KEEP_PROB,
learning_rate: LEARNING_RATE}
)
如何将train_op
中包含的图表保存到saved_model.pb
?
答案 0 :(得分:6)
最简单的方法是使用tf.train.write_graph
。通常,您只需要执行以下操作:
tf.train.write_graph(my_graph, path_to_model_pb,
'saved_model.pb', as_text=False)
如果您使用默认图表或任何其他tf.get_default_graph()
(或tf.Graph
)对象,则 my_graph
可以是tf.GraphDef
。
请注意,这会保存图形定义,可以将其可视化,但如果您有变量,除非先freeze the graph,否则它们的值将不会保存在那里(因为它们仅在会话对象中,而不是图表本身)。
答案 1 :(得分:1)
我将逐步介绍此问题:
要显示变量(如权重),偏差使用 tf.summary.histogram
weights = {
'h1': tf.Variable(tf.random_normal([n_input, n_hidden_1])),
'h2': tf.Variable(tf.random_normal([n_hidden_1, n_hidden_2])),
'out': tf.Variable(tf.random_normal([n_hidden_2, n_classes]))
}
tf.summary.histogram("weight1", weights['h1'])
tf.summary.histogram("weight2", weights['h2'])
tf.summary.histogram("weight3", weights['out'])
biases = {
'b1': tf.Variable(tf.random_normal([n_hidden_1])),
'b2': tf.Variable(tf.random_normal([n_hidden_2])),
'out': tf.Variable(tf.random_normal([n_classes]))
}
tf.summary.histogram("bias1", biases['b1'])
tf.summary.histogram("bias2", biases['b2'])
tf.summary.histogram("bias3", biases['out'])
cost = tf.sqrt(tf.reduce_mean(tf.squared_difference(pred, y)))
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
tf.summary.scalar('rmse', cost)
然后培训包括以下代码。
summaries = tf.summary.merge_all()
with tf.Session() as sess:
sess.run(init)
# Get data
writer = tf.summary.FileWriter("histogram_example", sess.graph)
# Training cycle
# Run optimization op (backprop) and cost op (to get loss value)
summ, p, _, c = sess.run([summ, pred, optimizer, cost], feed_dict={x: batch_x,
y: batch_y,})
writer.add_summary(summ, global_step=epoch*total_batch+i)