我正在使用tensorflow slim resnet_v2来提取图像功能。 resnet_v2_152.ckpt来自:resnet_v2_152.ckpt 这是我的代码。
import tensorflow as tf
import tensorflow.contrib.slim.python.slim.nets.resnet_v2 as resnet_v2
def cnn_model_fn(features, labels, mode):
net, end_points = resnet_v2.resnet_v2_152(inputs=features, is_training=mode == tf.estimator.ModeKeys.TRAIN)
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode, predictions=net)
else:
raise NotImplementedError('only support predict!')
def parse_filename(filename):
image_string = tf.read_file(filename)
image_decoded = tf.image.decode_jpeg(image_string, channels=3)
image_resized = tf.image.resize_images(image_decoded, [256, 256])
return image_resized
def dataset_input_fn(dataset, num_epochs=None, batch_size=128, shuffle=False, buffer_size=1000, seed=None):
def input_fn():
d = dataset.repeat(num_epochs).batch(batch_size)
if shuffle:
d = d.shuffle(buffer_size)
iterator = d.make_one_shot_iterator()
next_example = iterator.get_next()
return next_example
return input_fn
filenames = sorted(tf.gfile.Glob('/root/data/COCO/download/val2014/*'))
dataset = tf.contrib.data.Dataset.from_tensor_slices(filenames).map(parse_filename)
input_fn = dataset_input_fn(dataset, num_epochs=1, batch_size=1, shuffle=False)
estimator = tf.estimator.Estimator(model_fn=cnn_model_fn, model_dir=None)
es = estimator.predict(input_fn=input_fn,
checkpoint_path='/root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt')
print(es.__next__())
print("Done!")
我得到了这样的错误:
2017-09-10 22:06:36.875590: W tensorflow/core/framework/op_kernel.cc:1192] Not found: Tensor name "resnet_v2_152/block1/unit_1/bottleneck_v2/conv1/biases" not found in checkpoint files /root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt
[[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]
Traceback (most recent call last):
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1327, in _do_call
return fn(*args)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/client/session.py", line 1306, in _run_fn
status, run_metadata)
File "/usr/lib/python3.5/contextlib.py", line 66, in __exit__
next(self.gen)
File "/usr/local/lib/python3.5/dist-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.NotFoundError: Tensor name "resnet_v2_152/block1/unit_1/bottleneck_v2/conv1/biases" not found in checkpoint files /root/data/checkpoints/resnet_v2_152_2017_04_14/resnet_v2_152.ckpt
[[Node: save/RestoreV2_1 = RestoreV2[dtypes=[DT_FLOAT], _device="/job:localhost/replica:0/task:0/cpu:0"](_arg_save/Const_0_0, save/RestoreV2_1/tensor_names, save/RestoreV2_1/shape_and_slices)]]
[[Node: save/RestoreV2_242/_309 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_1240_save/RestoreV2_242", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]
我想我可以通过将conv1 / biases初始化为0来解决这个问题,但是tensorflow Estimator并没有给我这样的功能。我该如何解决这个问题?
答案 0 :(得分:1)
我认为,您期望加载预训练的权重,而不仅仅是在resnet中初始化变量。您应该考虑使用tf.train.Scaffold对象。
模型例程应如下所示
def cnn_model_fn(features, labels, mode):
with slim.arg_scope(resnet_v2.resnet_arg_scope()):
logits, end_points = resnet_v2.resnet_v2_152(features,
is_training=mode == tf.estimator.ModeKeys.TRAIN)
checkpoint_file = 'resnet_v2_152.ckpt'
init_fn = slim.assign_from_checkpoint_fn(
checkpoint_file,
[var for var in tf.global_variables()])
saver = tf.train.Saver(max_to_keep=10)
scaffold = tf.train.Scaffold(init_fn=init_fn, saver=saver)
if mode == tf.estimator.ModeKeys.PREDICT:
return tf.estimator.EstimatorSpec(mode=mode,
predictions={'logits': logits},
scaffold=scaffold)
else:
raise NotImplementedError('only support predict!')