TensorFlow:冻结模型似乎只存储输出节点?

时间:2017-09-16 18:48:13

标签: python tensorflow save

我正在尝试冻结一个已经训练好的 TensorFlow 模型。模型代码取自教程 Deep MNIST for Experts:

def weight_variable(shape):
    """Create a weight variable of the given shape.

    The starting values are drawn from a truncated normal distribution
    with a standard deviation of 0.1.
    """
    start_values = tf.truncated_normal(shape, stddev=0.1)
    weights = tf.Variable(start_values)
    return weights


def bias_variable(shape):
    """Create a bias variable of the given shape with every entry set to 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))


def conv2d(x_vector, w_matrix):
    """Convolve ``x_vector`` with ``w_matrix`` using stride 1 and 'SAME' padding."""
    unit_strides = [1, 1, 1, 1]
    convolved = tf.nn.conv2d(x_vector, w_matrix,
                             strides=unit_strides, padding='SAME')
    return convolved


def max_pool_2x2(x_vector):
    """Max-pool ``x_vector`` over 2x2 windows with stride 2 and 'SAME' padding."""
    # The same [1, 2, 2, 1] list serves as both window size and stride.
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x_vector, ksize=window, strides=list(window),
                          padding='SAME')


# File name the binary GraphDef is written to at the end of this script.
output_graph_name = 'my_graph.pb'

# Create model

# Number of classes: width of the label placeholder and of the final layer.
label_count = 12

# Network input: 1024 values per example, reshaped to 32x32x1 images below.
x = tf.placeholder(tf.float32, shape=[None, 1024], name="x")
# Ground-truth labels. This placeholder only feeds the loss and the accuracy
# ops; nothing in the network itself is computed from it.
y_ = tf.placeholder(tf.float32, shape=[None, label_count], name="y_")

# First convolutional layer: 5x5 kernels, 1 input channel, 32 output channels.
w_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1, 32, 32, 1])

h_conv1 = tf.nn.relu(conv2d(x_image, w_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)

# Second convolutional layer: 5x5 kernels, 32 input channels, 64 output channels.
w_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])

h_conv2 = tf.nn.relu(conv2d(h_pool1, w_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)

# Fully connected layer. After two stride-2 poolings the 32x32 input is 8x8
# with 64 channels, hence the 8 * 8 * 64 flattened width.
w_fc1 = weight_variable([8 * 8 * 64, 1024])
b_fc1 = bias_variable([1024])

h_pool2_flat = tf.reshape(h_pool2, [-1, 8 * 8 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, w_fc1) + b_fc1)

# Dropout keep probability. The placeholder is created without a name, so it
# receives an auto-generated one in the graph.
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

# Readout layer: one output per class.
w_fc2 = weight_variable([1024, label_count])
b_fc2 = bias_variable([label_count])

# Logits of the network; this is the tensor the whole network feeds into.
y_conv = tf.matmul(h_fc1_drop, w_fc2) + b_fc2

# Training and evaluation ops: mean softmax cross-entropy loss, an Adam step
# with learning rate 1e-4, and the fraction of correctly predicted examples.
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Add ops to save and restore all the variables.
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # ... train, etc etc ...
    # train_step.run(feed_dict={x: train_set[0],
    #                           y_: train_set[1],
    #                           keep_prob: 0.5})

    # Save the variables to disk.
    # "my_model.ckpt" is the checkpoint prefix; save_path holds the value
    # returned by saver.save.
    save_path = saver.save(sess, "my_model.ckpt")

    # Save graph
    # Written in binary form (as_text=False) into the current directory.
    tf.train.write_graph(sess.graph_def, '.', output_graph_name, as_text=False)

然后我尝试冻结我的模型

from tensorflow.python.tools import freeze_graph

# Arguments for freezing the graph written above together with its checkpoint.
# NOTE(review): output_node_names points at the label placeholder "y_", which
# has no inputs of its own - confirm this is the node that should be kept.
# NOTE(review): the variables were saved under "my_model.ckpt"; confirm that
# the "my_model" prefix used here matches.
freeze_options = dict(
    input_graph=output_graph_name,
    input_saver="",
    input_binary=True,
    input_checkpoint="my_model",
    output_node_names="y_",
    restore_op_name="save/restore_all",
    filename_tensor_name="save/Const:0",
    output_graph="frozen_graph.pb",
    clear_devices=True,
    initializer_nodes="",
)
freeze_graph.freeze_graph(**freeze_options)

现在,冻结后的图只包含 y_ 占位符,而不是整个相关联的网络。graph_util.extract_sub_graph 也只提取出了 y_。为什么会这样?如何冻结整个网络?我应该使用 y_conv 代替 y_ 吗?

图中上方的“Placeholder”节点对应 y_,下方的“Placeholder”节点对应 x。(图片:computation graph)

1 个答案:

答案 0 :(得分:2)

要确认 y_ 确实不是输出节点,可以添加以下代码:

def childs(t, d=0):
    """Recursively print ``t`` and every tensor it is computed from.

    Each line consists of ``d`` dashes, a space and the tensor's name, so
    the number of dashes is the distance from the tensor passed in first.
    No record of visited tensors is kept: a tensor reachable through
    several paths is printed once per path.

    Args:
        t: Object exposing ``name`` and ``op.inputs``, where ``op.inputs``
            is an iterable of objects of the same kind.
        d: Depth of ``t`` in the dump; the initial call leaves it at 0.
    """
    # A single pre-formatted argument gives the same output whether
    # ``print`` is the Python 2 statement or the Python 3 function.
    print('%s %s' % ('-' * d, t.name))
    for child in t.op.inputs:
        childs(child, d + 1)
# Dump everything `accuracy` is computed from, starting at depth 0.
childs(accuracy)

输出

Mean_1:0
- Cast_1:0
-- Equal:0
--- ArgMax:0
---- add_3:0
----- MatMul_1:0
------ dropout/mul:0
------- dropout/div:0
-------- Relu_2:0
--------- add_2:0
---------- MatMul:0
----------- Reshape_1:0
------------ MaxPool_1:0
------------- Relu_1:0
-------------- add_1:0
--------------- Conv2D_1:0
---------------- MaxPool:0
----------------- Relu:0
------------------ add:0
------------------- Conv2D:0
-------------------- Reshape:0
--------------------- x:0
--------------------- Reshape/shape:0
-------------------- Variable/read:0
--------------------- Variable:0
------------------- Variable_1/read:0
-------------------- Variable_1:0
---------------- Variable_2/read:0
----------------- Variable_2:0
--------------- Variable_3/read:0
---------------- Variable_3:0
------------ Reshape_1/shape:0
----------- Variable_4/read:0
------------ Variable_4:0
---------- Variable_5/read:0
----------- Variable_5:0
-------- Placeholder:0
------- dropout/Floor:0
-------- dropout/add:0
--------- Placeholder:0
--------- dropout/random_uniform:0
---------- dropout/random_uniform/mul:0
----------- dropout/random_uniform/RandomUniform:0
------------ dropout/Shape:0
------------- Relu_2:0
-------------- add_2:0
--------------- MatMul:0
---------------- Reshape_1:0
----------------- MaxPool_1:0
------------------ Relu_1:0
------------------- add_1:0
-------------------- Conv2D_1:0
--------------------- MaxPool:0
---------------------- Relu:0
----------------------- add:0
------------------------ Conv2D:0
------------------------- Reshape:0
-------------------------- x:0
-------------------------- Reshape/shape:0
------------------------- Variable/read:0
-------------------------- Variable:0
------------------------ Variable_1/read:0
------------------------- Variable_1:0
--------------------- Variable_2/read:0
---------------------- Variable_2:0
-------------------- Variable_3/read:0
--------------------- Variable_3:0
----------------- Reshape_1/shape:0
---------------- Variable_4/read:0
----------------- Variable_4:0
--------------- Variable_5/read:0
---------------- Variable_5:0
----------- dropout/random_uniform/sub:0
------------ dropout/random_uniform/max:0
------------ dropout/random_uniform/min:0
---------- dropout/random_uniform/min:0
------ Variable_6/read:0
------- Variable_6:0
----- Variable_7/read:0
------ Variable_7:0
---- ArgMax/dimension:0
--- ArgMax_1:0
---- y_:0
---- ArgMax_1/dimension:0
- Const_5:0

解决方案是

# Give the logits an explicit, stable node name. This line belongs with the
# model definition, before tf.train.write_graph serializes the graph, so that
# the new node is part of the graph file freeze_graph reads.
tf.identity(y_conv, name='my_output')
freeze_graph.freeze_graph(input_graph=output_graph_name,
                          input_saver="",
                          input_binary=True,
                          # Must be the same prefix that was passed to
                          # saver.save() ("my_model.ckpt"); otherwise the
                          # checkpoint is not found.
                          input_checkpoint="my_model.ckpt",
                          # Freeze from the named logits node instead of
                          # the label placeholder.
                          output_node_names="my_output",
                          restore_op_name="save/restore_all",
                          filename_tensor_name="save/Const:0",
                          output_graph="frozen_graph.pb",
                          clear_devices=True,
                          initializer_nodes="")