是否可以从冻结的图形中删除批次尺寸?

时间:2019-07-04 15:11:35

标签: python tensorflow

检查冻结的张量流模型:

wget https://storage.googleapis.com/download.tensorflow.org/models/inception_v3_2016_08_28_frozen.pb.tar.gz

我看到输入大小为Tensor 'input:0', which has shape '(1, 299, 299, 3)',我想知道是否可以使输入(None, 299, 299, 3)使得batch_size> 1可用的批量预测吗?

1 个答案:

答案 0 :(得分:1)

在一般情况下,可能无法执行此操作,因为可能存在依赖于第一维为1的操作(例如,假设input:0上使用tf.squeeze)。但是,您可以尝试用所需形状的占位符替换输入。您可以使用tf.graph_util.import_graph_def进行此操作。如果操作允许,则TensorFlow应该导入图以相应地调整节点形状。请参见以下示例:

import tensorflow as tf

# First graph
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [1, 10, 20], name='Input')
    y = tf.square(x, name='Output')
    print(y)
    # Tensor("Output:0", shape=(1, 10, 20), dtype=float32)
    gd = tf.get_default_graph().as_graph_def()

# Second graph
with tf.Graph().as_default():
    x = tf.placeholder(tf.float32, [None, 10, 20], name='Input')
    y, = tf.graph_util.import_graph_def(gd, input_map={'Input:0': x},
                                        return_elements=['Output:0'], name='')
    print(y)
    # Tensor("Output:0", shape=(?, 10, 20), dtype=float32)

在第一个图中,Output:0节点的形状为(1, 10, 20),这是根据Input:0张量的形状推断出来的。但是,当我从第一个图形中获取图形定义并加载到第二个图形中时,将Input:0张量替换为具有未定义的第一维的占位符,Output:0的形状将更新为{{1} }。如果我在第二个图中运行的操作中给出的输入值的第一维大于一个维度,则它将运行正常,因为该图是正确的。