从张量流

时间:2018-01-29 07:43:04

标签: python tensorflow

在tensorflow中,我需要从inception_v3预训练模型加载权重,以便在以下代码中使用:

with tf.variable_scope(scope, "InceptionV3", [images]) as scope:
            with slim.arg_scope(
                [slim.conv2d, slim.fully_connected],
                weights_regularizer=weights_regularizer,
                trainable=False):
                with slim.arg_scope(
                    [slim.conv2d],
                    weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
                    activation_fn=tf.nn.relu,
                    normalizer_fn=slim.batch_norm,
                    normalizer_params=batch_norm_params):
                    net, end_points = inception_v3_base(images, scope=scope)
                with tf.variable_scope("logits"):
                    shape = net.get_shape()
                    net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool")
                    net = slim.dropout(
                            net,
                            keep_prob=dropout_keep_prob,
                            is_training=False,
                            scope="dropout")
                    net = slim.flatten(net, scope="flatten")

    image_embeddings = tf.contrib.layers.fully_connected(
                    inputs=net,
                                num_outputs=512,
                                activation_fn=None,
                                weights_initializer=initializer,
                                biases_initializer=None,
                                scope=scope)

怎么可能这样做?你能举个简短的例子吗?

上述代码中有两个权重初始值设定项。我不知道我必须从模型中初始化权重,以及如何?

谢谢,

1 个答案:

答案 0 :(得分:3)

TL; DR :阅读下面列表中的第三点。

关于如何恢复模型的长期通用解释

每当您需要从检查点加载权重时,您需要匹配的模型定义才能在尝试恢复权重之前定义图形。这是必要的,因为检查点文件只包含变量的值,它没有关于图形本身结构的信息

可以通过不同方式检索模型结构:

  • 检查点附带匹配的.meta文件。在这种情况下,导入元图,然后通过以下方法恢复权重:

    new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
    new_saver.restore(sess, 'my-save-dir/my-model-10000')
    
  • 检查点附带一个匹配的.pb / .pbtxt文件,其中包含序列化的GraphDef。在这种情况下,从其定义加载图形,然后恢复权重:

    • .pbtxt

      with open('graph.pbtxt', 'r') as f:
          graph_def = tf.GraphDef()
          file_content = f.read()
          text_format.Merge(file_content, graph_def)
          tf.import_graph_def(graph_def, name='')
      saver = tf.train.Saver() # note: it is important that this is defined AFTER you import the graph definition or it won't find any variables in the graph to restore
      saver.restore(sess, "/tmp/model.ckpt")
      
    • .pb

      with gfile.FastGFile('graph.pb','rb') as f:
          graph_def = tf.GraphDef()
          graph_def.ParseFromString(f.read())
          tf.import_graph_def(graph_def, name='')
      saver = tf.train.Saver() # note: it is important that this is defined AFTER you import the graph definition or it won't find any variables in the graph to restore
      saver.restore(sess, "/tmp/model.ckpt")
      
  • 检查点附带一个包含模型定义的匹配python文件。在这种情况下,请阅读文件的文档并找到需要调用以定义模型的函数。然后,在您的脚本中导入该函数,在定义saver之前调用它,然后恢复变量'来自检查站的值:

    from inception_v3 import inception_v3
    
    logits, endpoints = inception_v3()
    saver = tf.train.Saver() # as above, it is important that this is defined after you define the graph, or it won't find any variables.
    saver.restore(sess, 'inception_v3.ckpt')
    

    注意:对于这种情况,您需要在保存检查点时调用函数完全(除非您选择性地尝试恢复某些变量),或者恢复将失败并显示错误。