在检查点中找不到alexnet_v2 / conv1 /偏差

时间:2017-10-30 08:42:03

标签: tensorflow tensorflow-slim

我尝试保存并恢复alxenet slim模型。但是当我运行代码saver.restore(sess, tf.train.latest_checkpoint('I:/model/mnist/'))时,我总是遇到这个错误。 并抛出错误: NotFoundError (see above for traceback): Key alexnet_v2/conv1/biases not found in checkpoint。 当我运行tf.global_variables()时,我只能获得conv2d的权重,并且结果中没有偏差。 我不明白问题是什么。这是我的代码:

这是我的亚历克斯模型

def alexnet_v2_arg_scope(weight_decay=0.0005,
                     stddev=0.1,
                     batch_norm_var_collection='moving_vars',
                     use_fused_batchnorm=True):
batch_norm_params = {
    # Decay for the moving averages.
    'decay': 0.9997,
    # epsilon to prevent 0s in variance.
    'epsilon': 0.001,
    # collection containing update_ops.
    'updates_collections': ops.GraphKeys.UPDATE_OPS,
    # Use fused batch norm if possible.
    'fused': use_fused_batchnorm,
    # collection containing the moving mean and moving variance.
    'variables_collections': {
        'beta': None,
        'gamma': None,
        'moving_mean': [batch_norm_var_collection],
        'moving_variance': [batch_norm_var_collection],
    }
}
with slim.arg_scope([slim.conv2d, slim.fully_connected],
                    activation_fn=tf.nn.relu,
                    biases_initializer=tf.constant_initializer,
                    normalizer_fn=slim.batch_norm,
                    normalizer_params=batch_norm_params,
                    weights_regularizer=slim.l2_regularizer(weight_decay)):
    with slim.arg_scope([slim.conv2d], padding='SAME'):
        with slim.arg_scope([slim.max_pool2d], padding='VALID') as arg_sc:
            return arg_sc

def alex_net(inputs,
         num_classes=10,
         is_training=True,
         droupout_keep_prob=0.5,
         spatial_squeze=True,
         scope='alexnet_v2'):
with tf.variable_scope(scope, 'alexnet_v2',[inputs]) as sc:
    end_points_collection = sc.name + '_end_points'
    with slim.arg_scope([slim.conv2d],
                        weights_initializer=trunc_norm(0.1),
                        biases_initializer=tf.constant_initializer(0.1),
                        outputs_collections=end_points_collection):
        inputs = tf.reshape(inputs,[-1,28,28,1])
        net = slim.conv2d(inputs, 64, [3, 3], 1, padding='VALID', scope='conv1')
        net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
        net = slim.conv2d(net, 128, [3, 3], scope='conv2')
        net = slim.max_pool2d(net, [3, 3], 2, scope='pool2')
        # net = slim.conv2d(net, 384, [3, 3], scope='conv3')
        # net = slim.conv2d(net, 384, [3, 3], scope='conv4')
        net = slim.conv2d(net, 256, [3, 3], scope='conv3')
        # net = slim.max_pool2d(net, [3, 3], 2, scope='pool5')
        net = slim.conv2d(net, 512, [3, 3], scope='conv4')

        with slim.arg_scope([slim.conv2d],
                            weights_initializer=trunc_norm(0.1),
                            biases_initializer=tf.constant_initializer(0.1)):
            # net = slim.conv2d(net, 1028, [6, 6], padding='VALID', scope='fc6')
            net = slim.avg_pool2d(net, [6,6], stride=1,padding='VALID',scope='avg_pool5' )
            net = slim.dropout(net, droupout_keep_prob, is_training=is_training, scope='dropout6')
            # net = slim.conv2d(net, 512, [1, 1], scope='fc7')
            # net = slim.dropout(net, droupout_keep_prob, is_training=is_training, scope='dropout7')
            net = slim.conv2d(net, num_classes, [1, 1],
                              activation_fn=None,
                              normalizer_fn=None,
                              biases_initializer=tf.zeros_initializer(),
                              scope='fc7x')

        end_points = slim.utils.convert_collection_to_dict(end_points_collection)
        if spatial_squeze:
            net = tf.squeeze(net, [1,2], name='fc8/squeezed')
            end_points[sc.name + '/fc8'] = net
        return net, end_point

保存模型

varialbes_to_restore = slim.get_variables_to_restore()
saver = tf.train.Saver(varialbes_to_restore)
saver.save(sess,'I:/model/mnist/')

恢复模式

with tf.Session() as sess:
    logits, _ = _alex_slim.alex_net(teX[:200])
    saver = tf.train.Saver()
    saver.restore(sess, tf.train.latest_checkpoint('I:/model/mnist/'))
    logits_var = sess.run(logits)
    print(logits_var)

1 个答案:

答案 0 :(得分:0)

Perhapse tensorflow希望您将其保存在检查点中,因为您将其保存在范围内。 如果为未保存的张量赋予名称,则会出现相同的错误。

如果您不想保存它,请不要给它命名或范围。