TensorFlow results are inconsistent between frozen graphs

Date: 2018-03-23 16:33:37

Tags: tensorflow protocol-buffers

When I freeze the graph and then run it elsewhere (on a mobile device), the output quality is low compared to inference on the server with my semantic segmentation model. It is basically a garbled version of what runs on the server. The frozen graph executes successfully, but something seems not to be initialized before freezing, even though the model-loading code is almost identical between the export script and the inference script.

An exported model can be run on the same images over and over, and it produces identical results for a given set of images, as expected.

However, every time I freeze the model, using exactly the same script and checkpoint, the frozen graph produces a different output for that same set of images.
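To show this concretely, I can freeze twice, rename the resulting .pb files, and diff the constant tensors of the two graphs. This is a minimal sketch; the `_run1`/`_run2` file names are placeholders for two outputs of the export script below:

import numpy as np
import tensorflow as tf
from tensorflow.python.framework import tensor_util

def load_graph_def(path):
    graph_def = tf.GraphDef()
    with tf.gfile.GFile(path, "rb") as f:
        graph_def.ParseFromString(f.read())
    return graph_def

def const_values(graph_def):
    # Map each Const node's name to its numpy value.
    return {n.name: tensor_util.MakeNdarray(n.attr["value"].tensor)
            for n in graph_def.node if n.op == "Const"}

a = const_values(load_graph_def("model/output_graph_run1.pb"))
b = const_values(load_graph_def("model/output_graph_run2.pb"))

# Constants present in both graphs but with different values are the
# frozen weights that changed between the two runs.
for name in sorted(set(a) & set(b)):
    if not np.array_equal(a[name], b[name]):
        print("differs:", name)
print("only in one graph:", sorted(set(a) ^ set(b)))

If the two freezes really do capture different weights, this prints differing node names even though the checkpoint was identical.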

import tensorflow as tf
from tensorflow.python.framework import graph_util

# get_arguments, preprocess, load, model_config, model_paths,
# cityscapes_class and ADE20k_class are defined elsewhere in the project.

def main():
    args = get_arguments()

    if args.dataset == 'cityscapes':
        num_classes = cityscapes_class
    else:
        num_classes = ADE20k_class

    shape = [320, 320]

    x = tf.placeholder(dtype=tf.float32, shape=(shape[0], shape[1], 3), name="input")
    img_tf = preprocess(x)

    model = model_config[args.model]
    net = model({'data': img_tf}, num_classes=num_classes, filter_scale=args.filter_scale)

    raw_output = net.layers['conv6_cls']
    raw_output_up = tf.image.resize_bilinear(raw_output, size=shape, align_corners=True)
    raw_output_maxed = tf.argmax(raw_output_up, axis=3, name="output")

    # Init tf Session. Note that global_variables_initializer assigns every
    # variable a fresh random value; the checkpoint restore below is expected
    # to overwrite these.
    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    init = tf.global_variables_initializer()

    sess.run(init)

    model_path = model_paths[args.model]
    ckpt = tf.train.get_checkpoint_state(model_path)
    if ckpt and ckpt.model_checkpoint_path:
        input_checkpoint = ckpt.model_checkpoint_path
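        # The default graph is already populated by the model() call above, so
        # import_meta_graph brings in a second copy of the checkpoint's nodes,
        # and colliding names are uniquified (e.g. conv1/weights_1). I am not
        # sure whether the restore then targets the same variables that the
        # "output" node depends on.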
        loader = tf.train.import_meta_graph(input_checkpoint + '.meta', clear_devices=True)
        load(loader, sess, ckpt.model_checkpoint_path)     
    else:
        print('No checkpoint file found at %s.' % model_path)
        exit()

    print("Loaded Model")

    # We retrieve the protobuf graph definition
    graph = tf.get_default_graph()
    input_graph_def = graph.as_graph_def()

    # We use a built-in TF helper to export variables to constants
    output_node_names = "output"  # the name given to the argmax node above
    output_graph_def = graph_util.convert_variables_to_constants(
        sess,                         # The session is used to retrieve the weights
        input_graph_def,              # The graph_def is used to retrieve the nodes
        output_node_names.split(",")  # The output node names are used to select the useful nodes
    )

    # Finally we serialize and dump the output graph to the filesystem
    with tf.gfile.GFile("model/output_graph.pb", "wb") as f:
        f.write(output_graph_def.SerializeToString())
    print("%d ops in the final graph." % len(output_graph_def.node))

0 Answers:

No answers