从张量流模型中获取权重

时间:2016-04-04 20:14:11

标签: python neural-network tensorflow conv-neural-network

您好我想从tensorflow微调VGG模型。我有两个问题。

如何从网络获取权重? trainable_variables为我返回空列表。

我使用现有的模型:https://github.com/ry/tensorflow-vgg16。 我找到关于获得权重的帖子,但是由于import_graph_def,这对我不起作用。 Get the value of some weights in a model trained by TensorFlow

import tensorflow as tf
import PIL.Image
import numpy as np

with open("../vgg16.tfmodel", mode='rb') as f:
  fileContent = f.read()

graph_def = tf.GraphDef()
graph_def.ParseFromString(fileContent)

images = tf.placeholder("float", [None, 224, 224, 3])

tf.import_graph_def(graph_def, input_map={ "images": images })
print("graph loaded from disk")

graph = tf.get_default_graph()

cat = np.asarray(PIL.Image.open('../cat224.jpg'))
print(cat.shape)
init = tf.initialize_all_variables()

with tf.Session(graph=graph) as sess:
  print(tf.trainable_variables() )
  sess.run(init)

1 个答案:

答案 0 :(得分:6)

pretrained VGG-16 model将所有模型参数编码为tf.constant()操作。 (例如,请参阅对tf.constant() here的调用。)因此,模型参数不会出现在tf.trainable_variables()中,并且模型在没有大量手术的情况下是不可变的:你会需要使用以相同值开头的tf.Variable对象替换常量节点,以便继续训练。

通常,在导入用于重新训练的图形时,应使用tf.train.import_meta_graph()函数,因为此函数会加载其他元数据(包括变量集合)。 tf.import_graph_def()功能较低,不会填充这些集合。