I am trying to train TensorFlow object detection on my own dataset.
What did I do?
- Used ssd_mobilenet_v1_pets.config as the base for my own pipeline config and changed num_classes and all the path-specific parts to match my environment.
- Used ssd_mobilenet_v1_coco from the TensorFlow model zoo as the checkpoint.
- Created a label map file containing all my labels (the first index starts at 1); a small sanity-check sketch follows this list.
- Created a TFRecord file from my dataset (the script is based on the TensorFlow sample script).
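For reference, this is roughly how I check that the ids in the label map line up with the categories used by the conversion script shown further down (a minimal sketch only: LABEL_MAP_PATH is a placeholder, not my real path, and label_map_util is the helper module shipped with the object_detection utils):
from object_detection.utils import label_map_util

LABEL_MAP_PATH = '/path/to/label_map.pbtxt'    # placeholder path
CATEGORIES_TO_TRAIN = ["apple", "dog", "cat"]  # same list as in the conversion script

label_map_dict = label_map_util.get_label_map_dict(LABEL_MAP_PATH)
for name in CATEGORIES_TO_TRAIN:
    # the id written into the TFRecord is CATEGORIES_TO_TRAIN.index(name) + 1,
    # so it should match the id stored in the label map
    print name, label_map_dict[name], CATEGORIES_TO_TRAIN.index(name) + 1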
What went wrong?
When I start the training with:
python tensorflow_models/research/object_detection/train.py --pipeline_config_path=/home/playground/ssd_mobilenet_v1.config --train_dir=/tmp/bla/
I get the following traceback:
Traceback (most recent call last):
File "tensorflow_models/research/object_detection/train.py", line 198, in <module>
tf.app.run()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "tensorflow_models/research/object_detection/train.py", line 194, in main
worker_job_name, is_chief, FLAGS.train_dir)
File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 296, in train
saver=saver)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/slim/python/slim/learning.py", line 767, in train
sv.stop(threads, close_summary_writer=True)
File "/usr/lib/python2.7/contextlib.py", line 35, in __exit__
self.gen.throw(type, value, traceback)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 964, in managed_session
self.stop(close_summary_writer=close_summary_writer)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 792, in stop
stop_grace_period_secs=self._stop_grace_secs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py", line 389, in join
six.reraise(*self._exc_info_to_raise)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/queue_runner_impl.py", line 238, in _run
enqueue_callable()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1235, in _single_operation_run
target_list_as_strings, status, None)
File "/usr/lib/python2.7/contextlib.py", line 24, in __exit__
self.gen.next()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/errors_impl.py", line 466, in raise_exception_on_not_ok_status
pywrap_tensorflow.TF_GetCode(status))
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0] = 2 is not in [0, 1)
[[Node: cond/RandomCropImage/PruneCompleteleyOutsideWindow/Gather/Gather_1 = Gather[Tindices=DT_INT64, Tparams=DT_INT64, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](cond/RandomCropImage/PruneCompleteleyOutsideWindow/Gather/Gather_1/Switch:1, cond/RandomCropImage/PruneCompleteleyOutsideWindow/Reshape)]]
Unfortunately, I have no idea what TensorFlow is trying to tell me with this traceback, nor where I should start looking for my mistake. I have already checked every step for possible errors, but so far I cannot find any.
Edit: I also tried using this config file, as @eshirima suggested. Again I changed the num_classes parameter and all the other parameters marked with PATH_TO_BE_CONFIGURED. However, it now fails with the following error message:
INFO:tensorflow:Starting Queues.
INFO:tensorflow:global_step/sec: 0
INFO:tensorflow:Error reported to Coordinator: <class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>, indices[0] = 2 is not in [0, 1)
[[Node: Loss/Gather_29 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Loss/Pad_5, Loss/Reshape_47)]]
Caused by op u'Loss/Gather_29', defined at:
File "tensorflow_models/research/object_detection/train.py", line 198, in <module>
tf.app.run()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "tensorflow_models/research/object_detection/train.py", line 194, in main
worker_job_name, is_chief, FLAGS.train_dir)
File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 192, in train
clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
File "/home/playground/tensorflow_models/research/slim/deployment/model_deploy.py", line 193, in create_clones
outputs = model_fn(*args, **kwargs)
File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 133, in _create_losses
losses_dict = detection_model.loss(prediction_dict)
File "/home/playground/tensorflow_models/research/object_detection/meta_architectures/ssd_meta_arch.py", line 411, in loss
self.groundtruth_lists(fields.BoxListFields.classes))
File "/home/playground/tensorflow_models/research/object_detection/meta_architectures/ssd_meta_arch.py", line 485, in _assign_targets
groundtruth_classes_with_background_list)
File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 438, in batch_assign_targets
anchors, gt_boxes, gt_class_targets)
File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 154, in assign
match)
File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 250, in _create_classification_targets
matched_cls_targets = tf.gather(groundtruth_labels, matched_gt_indices)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 2409, in gather
validate_indices=validate_indices, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 1219, in gather
validate_indices=validate_indices, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): indices[0] = 2 is not in [0, 1)
[[Node: Loss/Gather_29 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Loss/Pad_5, Loss/Reshape_47)]]
Traceback (most recent call last):
File "tensorflow_models/research/object_detection/train.py", line 198, in <module>
tf.app.run()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "tensorflow_models/research/object_detection/train.py", line 194, in main
worker_job_name, is_chief, FLAGS.train_dir)
File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 296, in train
saver=saver)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/contrib/slim/python/slim/learning.py", line 767, in train
sv.stop(threads, close_summary_writer=True)
File "/usr/lib/python2.7/contextlib.py", line 35, in __exit__
self.gen.throw(type, value, traceback)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 964, in managed_session
self.stop(close_summary_writer=close_summary_writer)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 792, in stop
stop_grace_period_secs=self._stop_grace_secs)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py", line 389, in join
six.reraise(*self._exc_info_to_raise)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py", line 296, in stop_on_exception
yield
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/coordinator.py", line 494, in run
self.run_loop()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/training/supervisor.py", line 994, in run_loop
self._sv.global_step])
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 895, in run
run_metadata_ptr)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1124, in _run
feed_dict_tensor, options, run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1321, in _do_run
options, run_metadata)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/client/session.py", line 1340, in _do_call
raise type(e)(node_def, op, message)
tensorflow.python.framework.errors_impl.InvalidArgumentError: indices[0] = 2 is not in [0, 1)
[[Node: Loss/Gather_29 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Loss/Pad_5, Loss/Reshape_47)]]
Caused by op u'Loss/Gather_29', defined at:
File "tensorflow_models/research/object_detection/train.py", line 198, in <module>
tf.app.run()
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/platform/app.py", line 48, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "tensorflow_models/research/object_detection/train.py", line 194, in main
worker_job_name, is_chief, FLAGS.train_dir)
File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 192, in train
clones = model_deploy.create_clones(deploy_config, model_fn, [input_queue])
File "/home/playground/tensorflow_models/research/slim/deployment/model_deploy.py", line 193, in create_clones
outputs = model_fn(*args, **kwargs)
File "/home/playground/tensorflow_models/research/object_detection/trainer.py", line 133, in _create_losses
losses_dict = detection_model.loss(prediction_dict)
File "/home/playground/tensorflow_models/research/object_detection/meta_architectures/ssd_meta_arch.py", line 411, in loss
self.groundtruth_lists(fields.BoxListFields.classes))
File "/home/playground/tensorflow_models/research/object_detection/meta_architectures/ssd_meta_arch.py", line 485, in _assign_targets
groundtruth_classes_with_background_list)
File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 438, in batch_assign_targets
anchors, gt_boxes, gt_class_targets)
File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 154, in assign
match)
File "/home/playground/tensorflow_models/research/object_detection/core/target_assigner.py", line 250, in _create_classification_targets
matched_cls_targets = tf.gather(groundtruth_labels, matched_gt_indices)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/array_ops.py", line 2409, in gather
validate_indices=validate_indices, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/ops/gen_array_ops.py", line 1219, in gather
validate_indices=validate_indices, name=name)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/op_def_library.py", line 767, in apply_op
op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2630, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1204, in __init__
self._traceback = self._graph._extract_stack() # pylint: disable=protected-access
InvalidArgumentError (see above for traceback): indices[0] = 2 is not in [0, 1)
[[Node: Loss/Gather_29 = Gather[Tindices=DT_INT32, Tparams=DT_FLOAT, validate_indices=true, _device="/job:localhost/replica:0/task:0/cpu:0"](Loss/Pad_5, Loss/Reshape_47)]]
Edit: I added some code that shows how the TFRecord file is generated. The whole script is rather long, so I tried to cut it down to the relevant parts. If I left out something you are interested in, please let me know.
import tensorflow as tf
from object_detection.utils import dataset_util  # only the imports relevant to this excerpt

CATEGORIES_TO_TRAIN = ["apple", "dog", "cat"]

def createTFExample(img):
    imageFormat = ""
    if img.format == 'JPEG':
        imageFormat = b'jpeg'
    elif img.format == 'PNG':
        imageFormat = b'png'
    else:
        print 'Unknown Image format %s' % (img.format,)
        return None

    width, height = img.size
    filename = str(img.filename)
    encodedImageData = img.bytesIO

    # normalized bounding box coordinates
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    for annotation in img.annotations:
        xmins.append((annotation.left / width))
        xmaxs.append((annotation.left + annotation.width) / width)
        ymins.append((annotation.top / height))
        ymaxs.append((annotation.top + annotation.height) / height)

    # we might have some images in our dataset which don't have an annotation; skip those
    if (len(xmins) == 0) or (len(xmaxs) == 0) or (len(ymins) == 0) or (len(ymaxs) == 0):
        return None

    label = [img.label.encode('utf8')]
    classes = [(CATEGORIES_TO_TRAIN.index(img.label) + 1)]  # class indexes start with 1

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encodedImageData),
        'image/format': dataset_util.bytes_feature(imageFormat),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(label),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example

def createTfRecordFile(images):
    writer = tf.python_io.TFRecordWriter(TFRECORD_OUTPUT_PATH)
    for img in images:
        t = createTFExample(img)
        if t is not None:
            writer.write(t.SerializeToString())
    writer.close()
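For completeness, this is the kind of snippet I use to dump the first record of the generated file and eyeball the stored fields (debugging only, not part of the pipeline):
for record in tf.python_io.tf_record_iterator(TFRECORD_OUTPUT_PATH):
    example = tf.train.Example()
    example.ParseFromString(record)
    feature = example.features.feature
    # print a few of the stored fields of the first example
    print feature['image/object/class/label'].int64_list.value
    print feature['image/object/class/text'].bytes_list.value
    print feature['image/object/bbox/xmin'].float_list.value
    break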
Any help pointing me in the right direction is greatly appreciated!
Answer (score: 2)
I had a similar problem, and making the label list and the classes list the same length as the bounding-box lists fixed it for me.
Specifically, in createTFExample(), the elements of label = [img.label.encode('utf8')] and classes = [(CATEGORIES_TO_TRAIN.index(img.label) + 1)] should correspond one-to-one to the elements of the bounding-box annotation lists:
xmins = []
xmaxs = []
ymins = []
ymaxs = []
for annotation in img.annotations:
    xmins.append((annotation.left / width))
    xmaxs.append((annotation.left + annotation.width) / width)
    ymins.append((annotation.top / height))
    ymaxs.append((annotation.top + annotation.height) / height)
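That mismatch also fits the error message: the class target tensor ends up with a single row, so as soon as a matched box refers to any other index, the gather fails with "indices[0] = 2 is not in [0, 1)". A tiny standalone illustration (not your actual graph, just the same kind of gather on a CPU session):
import tensorflow as tf

params = tf.constant([[1.0, 0.0]])           # only one groundtruth class row
indices = tf.constant([2], dtype=tf.int32)   # but a matched box points at index 2
with tf.Session() as sess:
    # raises InvalidArgumentError: indices[0] = 2 is not in [0, 1)
    sess.run(tf.gather(params, indices))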
From the structure of your code I assume that each img object contains only one object type; in that case, write
label = [img.label.encode('utf8')] * len(xmins)
classes = [(CATEGORIES_TO_TRAIN.index(img.label) + 1)] * len(xmins)
or use anything else that gives you the number of objects in the image, so that the label and classes lists have the same length as the bounding-box lists.
If an img object can contain objects of several types, build a list of object names and a list of class ids whose indices match those of the annotation lists.
The resulting lists should then look like this:
xmins = [a_xmin, b_xmin, c_xmin]
ymins = [a_ymin, b_ymin, c_ymin]
xmaxs = [a_xmax, b_xmax, c_xmax]
ymaxs = [a_ymax, b_ymax, c_ymax]
labels = [a_label, b_label, c_label]
classes = [a_classid, b_classid, c_classid]
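In code that could look roughly like this (a sketch only: it assumes each annotation carries its own object name, e.g. a hypothetical annotation.label attribute, which your excerpt does not show):
labels = []
classes = []
for annotation in img.annotations:
    name = annotation.label  # hypothetical per-annotation label; adapt to your data model
    labels.append(name.encode('utf8'))
    classes.append(CATEGORIES_TO_TRAIN.index(name) + 1)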
That made my problem go away. Hope this helps!