在tensorflow中,我需要从inception_v3预训练模型加载权重,以便在以下代码中使用:
with tf.variable_scope(scope, "InceptionV3", [images]) as scope:
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=weights_regularizer,
trainable=False):
with slim.arg_scope(
[slim.conv2d],
weights_initializer=tf.truncated_normal_initializer(stddev=stddev),
activation_fn=tf.nn.relu,
normalizer_fn=slim.batch_norm,
normalizer_params=batch_norm_params):
net, end_points = inception_v3_base(images, scope=scope)
with tf.variable_scope("logits"):
shape = net.get_shape()
net = slim.avg_pool2d(net, shape[1:3], padding="VALID", scope="pool")
net = slim.dropout(
net,
keep_prob=dropout_keep_prob,
is_training=False,
scope="dropout")
net = slim.flatten(net, scope="flatten")
image_embeddings = tf.contrib.layers.fully_connected(
inputs=net,
num_outputs=512,
activation_fn=None,
weights_initializer=initializer,
biases_initializer=None,
scope=scope)
怎么可能这样做?你能举个简短的例子吗?
上述代码中有两个权重初始值设定项。我不知道我必须从模型中初始化权重,以及如何?
谢谢,
答案 0 :(得分:3)
TL; DR :阅读下面列表中的第三点。
每当您需要从检查点加载权重时,您需要匹配的模型定义才能在尝试恢复权重之前定义图形。这是必要的,因为检查点文件只包含变量的值,它没有关于图形本身结构的信息。
可以通过不同方式检索模型结构:
检查点附带匹配的.meta
文件。在这种情况下,导入元图,然后通过以下方法恢复权重:
new_saver = tf.train.import_meta_graph('my-save-dir/my-model-10000.meta')
new_saver.restore(sess, 'my-save-dir/my-model-10000')
检查点附带一个匹配的.pb
/ .pbtxt
文件,其中包含序列化的GraphDef
。在这种情况下,从其定义加载图形,然后恢复权重:
.pbtxt
:
with open('graph.pbtxt', 'r') as f:
graph_def = tf.GraphDef()
file_content = f.read()
text_format.Merge(file_content, graph_def)
tf.import_graph_def(graph_def, name='')
saver = tf.train.Saver() # note: it is important that this is defined AFTER you import the graph definition or it won't find any variables in the graph to restore
saver.restore(sess, "/tmp/model.ckpt")
.pb
:
with gfile.FastGFile('graph.pb','rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
tf.import_graph_def(graph_def, name='')
saver = tf.train.Saver() # note: it is important that this is defined AFTER you import the graph definition or it won't find any variables in the graph to restore
saver.restore(sess, "/tmp/model.ckpt")
检查点附带一个包含模型定义的匹配python文件。在这种情况下,请阅读文件的文档并找到需要调用以定义模型的函数。然后,在您的脚本中导入该函数,在定义saver
之前调用它,然后恢复变量'来自检查站的值:
from inception_v3 import inception_v3
logits, endpoints = inception_v3()
saver = tf.train.Saver() # as above, it is important that this is defined after you define the graph, or it won't find any variables.
saver.restore(sess, 'inception_v3.ckpt')
注意:对于这种情况,您需要在保存检查点时调用函数完全(除非您选择性地尝试恢复某些变量),或者恢复将失败并显示错误。