在现有的tensorflow图中设置静态形状,其中动态形状用于输入

时间:2019-01-17 23:32:22

标签: python tensorflow

我得到了一个冻结的( .pb)图,该图具有动态的输入形状(例如“无,无,无,3”或“?x?x?x3”)。我想将它们设置为静态形状(例如“ 1、320、320、3”),但是我不确定如何将形状更改为输入占位符,以将更改应用于随后的所有图层。在这种情况下,我没有可用的代码或ckpt文件,因此必须在冻结的( .pb)图上进行这项工作。

我已经尝试了什么?

我创建了一个简单的示例代码来制作一个简单的图形并将其保存为冻结的图形,可用于测试不同的方法。此图创建如下:

import tensorflow as tf

def simple_cnn_graph():
    graph = tf.Graph()
    with graph.as_default():
        input_layer = tf.placeholder(shape=[None, None, None, 3], dtype=tf.float32)
        conv1 = tf.layers.conv2d(
            inputs=input_layer,
            filters=32,
            kernel_size=[5, 5],
            padding="same",
            activation=tf.nn.relu)
        pool1 = tf.layers.max_pooling2d(inputs=conv1, pool_size=[2, 2], strides=2, name='pool1')
        conv2 = tf.layers.conv2d(
            inputs=pool1,
            filters=16,
            kernel_size=[5, 5],
            padding="same",
            activation=tf.nn.relu)
        pool2 = tf.layers.max_pooling2d(inputs=conv2, pool_size=[2, 2], strides=2, name='pool2')
    return graph, pool2

if __name__=='__main__':
    graph, output = simple_cnn_graph()
    with tf.Session(graph=graph) as sess:
        sess.run(tf.global_variables_initializer())
        graph_def = tf.graph_util.convert_variables_to_constants(sess, \
                            tf.get_default_graph().as_graph_def(), [output.name.split(':')[0]])

    frozen_file='./frozen.pb'
    with open(frozen_file, 'wb') as f:
        f.write(graph_def.SerializeToString())

    print([n.name for n in graph.as_graph_def().node])

我尝试了两种方法:

1)我尝试在以下位置使用transform_graph工具: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/graph_transforms/README.md 使用strip_unused_nodes,但是它不起作用,因为我没有将任何张量转换为占位符。

2)在链接中的注释之后,我取得了一些成功: https://github.com/tensorflow/tensorflow/issues/5680#issuecomment-405128390 我可以使用tf.import_graph_def的{​​{1}}来映射新的占位符的位置,但是我正在寻找一种更简单,可通用的解决方案,将来可以将其应用于任何此类冻结网络图(例如类似于transform_graph)。以下是我使用input_map方法的代码

tf.import_graph_def

打印输出为:

import tensorflow as tf

def load_frozen_graph(frozen_file='frozen.pb'):
    graph = tf.Graph()
    with graph.as_default():
      od_graph_def = tf.GraphDef()
      with tf.gfile.GFile(frozen_file, 'rb') as fid:
        serialized_graph = fid.read()
        od_graph_def.ParseFromString(serialized_graph)
        tf.import_graph_def(od_graph_def, name='')
    return graph

graph = load_frozen_graph('./frozen.pb')

print('Tensor shapes before import map')
input_tensor = graph.get_tensor_by_name('Placeholder:0')
print(input_tensor)
output_tensor = graph.get_tensor_by_name('pool2/MaxPool:0')
print(output_tensor)

new_graph = tf.Graph()
with new_graph.as_default():
    new_input = tf.placeholder(dtype=tf.float32, shape=[1, 320, 320, 3], name='Placeholder')
    tf.import_graph_def(graph.as_graph_def(), name='', input_map={'Placeholder': new_input})

print('Tensor shapes after import map')
input_tensor = new_graph.get_tensor_by_name('Placeholder:0')
print(input_tensor)
output_tensor = new_graph.get_tensor_by_name('pool2/MaxPool:0')
print(output_tensor)

如果有人可以将我指向正确的方向或纠正我,如果我在上述代码/帖子中犯了任何错误或对tf形状有任何错误的理解,我将不胜感激。

0 个答案:

没有答案