Tensorflow:如何在新图表中使用预训练权重?

时间:2018-05-07 13:22:15

标签: python tensorflow

我正在尝试使用具有python框架的tensorflow来构建具有CNN的对象检测器。我想训练我的模型首先进行物体识别(分类),然后使用预定模型的几个卷积层训练它来预测边界框。我将需要替换完全连接的层,可能还需要一些最后的卷积层。所以,出于这个原因,我想知道是否有可能从张量流图中导入权重,用于训练对象分类器到我将训练进行对象检测的新定义的图形。所以基本上我想做这样的事情:

# here I initialize the new graph
conv_1=tf.nn.conv2d(in, weights_from_old_graph)
conv_2=tf.nn.conv2d(conv_1, weights_from_old_graph)
...
conv_n=tf.nn.nnconv2d(conv_n-1,randomly_initialized_weights)
fc_1=tf.matmul(conv_n, randomly_initalized_weights)

2 个答案:

答案 0 :(得分:7)

使用没有参数的保护程序来保存整个模型。

tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer = tf.initializers.random_normal)
v2 = tf.get_variable("v2", [5], initializer = tf.initializers.random_normal)
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, save_path='./test-case.ckpt')

    print(v1.eval())
    print(v2.eval())
saver = None
v1 = [ 2.1882825   1.159807   -0.26564872]
v2 = [0.11437789 0.5742971 ]

然后,在要还原到某些值的模型中,将要还原的变量名称列表或{"variable name": variable}的字典传递给Saver

tf.reset_default_graph()
b1 = tf.get_variable("b1", [3], initializer= tf.initializers.random_normal)
b2 = tf.get_variable("b2", [3], initializer= tf.initializers.random_normal)
saver = tf.train.Saver(var_list={'v1': b1})

with tf.Session() as sess:
  saver.restore(sess, "./test-case.ckpt")
  print(b1.eval())
  print(b2.eval())
INFO:tensorflow:Restoring parameters from ./test-case.ckpt
b1 = [ 2.1882825   1.159807   -0.26564872]
b2 = FailedPreconditionError: Attempting to use uninitialized value b2

答案 1 :(得分:0)

虽然我同意Aechlys恢复变数。当我们想要修复这些变量时,问题就更难了。例如,我们训练了这些变量,我们希望在另一个模型中使用它们,但这次没有训练它们(训练新的变量,如转移学习)。你可以看到我发布的答案here

快速举例:

 with tf.session() as sess:
    new_saver = tf.train.import_meta_graph(pathToMeta)
    new_saver.restore(sess, pathToNonMeta) 

    weight1 = sess.run(sess.graph.get_tensor_by_name("w1:0")) 


 tf.reset_default_graph() #this will eliminate the variables we restored


 with tf.session() as sess:
    weights = 
       {
       '1': tf.Variable(weight1 , name='w1-bis', trainable=False)
       }
...

我们现在确定恢复的变量不是图表的一部分。