使用变量更改保存/恢复

时间:2018-07-05 07:14:25

标签: python tensorflow

我是Tensorflow的新手。出于学习目的,我使用VGG16。 首先,我从文件加载经过训练的权重,而我不想训练整个网络,因此我冻结除了最后两层以外的所有层的权重。

weightsFile = {'conv1_1_W': tf.Variable(weights['conv1_1_W'], name='conv1_1_W', trainable=False),
               'conv1_2_W': tf.Variable(weights['conv1_2_W'], name='conv1_2_W', trainable=False),
               'conv2_1_W': tf.Variable(weights['conv2_1_W'], name='conv2_1_W', trainable=False),
               'conv2_2_W': tf.Variable(weights['conv2_2_W'], name='conv2_2_W', trainable=False),
               'conv3_1_W': tf.Variable(weights['conv3_1_W'], name='conv3_1_W', trainable=False),
               'conv3_2_W': tf.Variable(weights['conv3_2_W'], name='conv3_2_W', trainable=False),
               'conv3_3_W': tf.Variable(weights['conv3_3_W'], name='conv3_3_W', trainable=False),
               'conv4_1_W': tf.Variable(weights['conv4_1_W'], name='conv4_1_W', trainable=False),
               'conv4_2_W': tf.Variable(weights['conv4_2_W'], name='conv4_2_W', trainable=False),
               'conv4_3_W': tf.Variable(weights['conv4_3_W'], name='conv4_3_W', trainable=False),
               'conv5_1_W': tf.Variable(weights['conv5_1_W'], name='conv5_1_W', trainable=False),
               'conv5_2_W': tf.Variable(weights['conv5_2_W'], name='conv5_2_W', trainable=False),
               'conv5_3_W': tf.Variable(weights['conv5_3_W'], name='conv5_3_W', trainable=False),
               'fc6_W':tf.Variable(weights['fc6_W'], name='fc6_W', trainable=False),
               'fc7_W':tf.Variable(weights['fc7_W'], name='fc7_W'),
               'fc8_W':tf.Variable(weights['fc8_W'], name='fc8_W'),
               'out_W': tf.Variable(tf.random_normal([1000, 1]), name='out_W'),
               'conv1_1_b': tf.Variable(weights['conv1_1_b'], name='conv1_1_b'),
               'conv1_2_b': tf.Variable(weights['conv1_2_b'], name='conv1_2_b'),
               'conv2_1_b': tf.Variable(weights['conv2_1_b'], name='conv2_1_b'),
               'conv2_2_b': tf.Variable(weights['conv2_2_b'], name='conv2_2_b'),
               'conv3_1_b': tf.Variable(weights['conv3_1_b'], name='conv3_1_b'),
               'conv3_2_b': tf.Variable(weights['conv3_2_b'], name='conv3_2_b'),
               'conv3_3_b': tf.Variable(weights['conv3_3_b'], name='conv3_3_b'),
               'conv4_1_b': tf.Variable(weights['conv4_1_b'], name='conv4_1_b'),
               'conv4_2_b': tf.Variable(weights['conv4_2_b'], name='conv4_2_b'),
               'conv4_3_b': tf.Variable(weights['conv4_3_b'], name='conv4_3_b'),
               'conv5_1_b': tf.Variable(weights['conv5_1_b'], name='conv5_1_b'),
               'conv5_2_b': tf.Variable(weights['conv5_2_b'], name='conv5_2_b'),
               'conv5_3_b': tf.Variable(weights['conv5_3_b'], name='conv5_3_b'),
               'fc6_b':tf.Variable(weights['fc6_b'], name='fc6_b'),
               'fc7_b':tf.Variable(weights['fc7_b'], name='fc7_b'),
               'fc8_b':tf.Variable(weights['fc8_b'], name='fc8_b'),
               'out_b': tf.Variable(tf.random_normal([1]), name='out_b')
               }

然后这是培训摘要

prediction = convolutional_neural_network(x_1, x_2, weightsFile)
cost = tf.losses.mean_squared_error(labels=y_1, predictions=prediction)
tf.summary.histogram("optimizer", cost)
learning_rate = 0.001
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)

我需要保存模态,所以我正在使用保护程序

saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.restore(sess, restore_path)
    o, c = sess.run([optimizer, cost],
                                     feed_dict={x_1: epoch_x1, x_2: epoch_x2, y_1: epoch_y})
    save_path = saver.save(sess, model_path)

总体而言,这是代码,我省略了一些不需要理解逻辑的东西。它在给定的数据上运行良好并且正在培训网络。我想解冻fc6_W时出现问题。它说

  

在检查点中找不到密钥fc6_W / Adam

我不确定如何实现这一目标。以及在我需要逐步解冻图层的情况下如何保存和还原。

0 个答案:

没有答案