我使用tensorflow对象检测api为微调任务创建数据集。
我的目录结构是:
train/
- imgs/
---- img1.jpg
- ann/
---- img1.csv
其中每张图像对应一个 csv 文件,文件中每一行的格式为 label, x, y, w, h
我用这个脚本来保存tfrecord:
import tensorflow as tf
from os import listdir
import os
from os.path import isfile, join
import csv
import json
from object_detection.utils import dataset_util
# Command-line flags: --output_path is the file the TFRecord writer in
# main() is opened on.
flags = tf.app.flags
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
FLAGS = flags.FLAGS
# Maps each label string to its integer class id. Filled in by ex_info() and
# read by create_tf_example() and main().
LABEL_DICT = {}
# NOTE(review): not referenced anywhere in the visible code.
counter = 0
def create_tf_example(example):
    """Build a ``tf.train.Example`` for one image and its bounding boxes.

    Args:
        example: dict with a ``'path'`` entry (path of the image file) and a
            ``'boxes'`` entry, a list of dicts keyed by ``'label'``, ``'x'``,
            ``'y'``, ``'w'`` and ``'h'``. ``x``/``y`` are used as the left/top
            corner and ``w``/``h`` as the box size; they are presumably given
            in pixels (they are divided by the image size to normalize them).

    Returns:
        The populated ``tf.train.Example``.

    Raises:
        KeyError: if a box label has no entry in ``LABEL_DICT``.
    """
    # TODO(user): Populate the following variables from your example.
    # NOTE(review): the image size is hard-coded; every image is assumed to
    # be 720x404 -- confirm against the dataset.
    height = 404  # Image height
    width = 720  # Image width
    # Filename of the image. Empty if image is not from file
    filename = example['path'].encode('utf-8').strip()
    with tf.gfile.GFile(example['path'], 'rb') as fid:
        encoded_image_data = fid.read()
    image_format = 'jpeg'.encode('utf-8').strip()  # b'jpeg' or b'png'

    xmins = []  # List of normalized left x coordinates in bounding box (1 per box)
    xmaxs = []  # List of normalized right x coordinates in bounding box (1 per box)
    ymins = []  # List of normalized top y coordinates in bounding box (1 per box)
    ymaxs = []  # List of normalized bottom y coordinates in bounding box (1 per box)
    classes_text = []  # List of string class name of bounding box (1 per box)
    classes = []  # List of integer class id of bounding box (1 per box)

    for box in example['boxes']:
        left = int(box['x'])
        top = int(box['y'])
        right = left + int(box['w'])
        bottom = top + int(box['h'])
        # The far edge must be summed BEFORE normalizing: the previous
        # `w + x / width` added the raw box width to an already-normalized x.
        # Dividing by float(...) keeps true division on any interpreter.
        xmins.append(left / float(width))
        xmaxs.append(right / float(width))
        ymins.append(top / float(height))
        ymaxs.append(bottom / float(height))
        classes_text.append(box['label'].encode('utf-8'))
        classes.append(int(LABEL_DICT[box['label']]))

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': dataset_util.int64_feature(height),
        'image/width': dataset_util.int64_feature(width),
        'image/filename': dataset_util.bytes_feature(filename),
        'image/source_id': dataset_util.bytes_feature(filename),
        'image/encoded': dataset_util.bytes_feature(encoded_image_data),
        'image/format': dataset_util.bytes_feature(image_format),
        'image/object/bbox/xmin': dataset_util.float_list_feature(xmins),
        'image/object/bbox/xmax': dataset_util.float_list_feature(xmaxs),
        'image/object/bbox/ymin': dataset_util.float_list_feature(ymins),
        'image/object/bbox/ymax': dataset_util.float_list_feature(ymaxs),
        'image/object/class/text': dataset_util.bytes_list_feature(classes_text),
        'image/object/class/label': dataset_util.int64_list_feature(classes),
    }))
    return tf_example
def ex_info(img_path, ann_path):
    """Read one annotation CSV and describe the example for its image.

    Every row of the CSV is parsed as ``label, x, y, w, h`` (the file is
    read without a header row). Each label seen for the first time is
    registered in the module-level ``LABEL_DICT`` with the next free id,
    starting at 1.

    Args:
        img_path: path of the image the annotations belong to.
        ann_path: path of the CSV annotation file.

    Returns:
        A dict ``{"path": img_path, "boxes": [row_dict, ...]}`` where every
        row dict maps the five column names to the raw string values.
    """
    head = ['label', 'x', 'y', 'w', 'h']
    boxes = []
    # newline='' is what the csv module requires of its input file, and the
    # explicit encoding keeps non-ASCII labels independent of the locale
    # (main() writes the label map as UTF-8 as well).
    with open(ann_path, 'r', newline='', encoding='utf-8') as csvfile:
        for box in csv.DictReader(csvfile, fieldnames=head):
            boxes.append(box)
            # Only assigns an id when the label is new; known labels keep theirs.
            LABEL_DICT.setdefault(box['label'], len(LABEL_DICT) + 1)
    return {
        "path": img_path,
        "boxes": boxes,
    }
def main(_):
    """Scan ``train/imgs`` and ``train/ann`` and write the label map.

    For every file in ``train/imgs`` the annotation file with the same base
    name and a ``.csv`` extension is read from ``train/ann`` via ``ex_info``,
    which also collects the labels. The collected labels are then written to
    ``tfTrain/data/labels_map.pbtxt``. A TFRecord writer is opened on
    ``FLAGS.output_path`` and always closed, even when an error occurs.

    Args:
        _: unused positional argument supplied by ``tf.app.run``.
    """
    writer = tf.python_io.TFRecordWriter(FLAGS.output_path)
    try:
        # TODO(user): Write code to read in your dataset to examples variable
        dataset_dir = "train"
        ann_dir = join(dataset_dir, "ann")
        imgs_dir = join(dataset_dir, "imgs")
        labelDest = "tfTrain/data/labels_map.pbtxt"

        # listdir() order is arbitrary; sorting keeps the label ids (assigned
        # in order of first appearance) stable from run to run.
        imgs = [join(imgs_dir, f) for f in sorted(listdir(imgs_dir))
                if isfile(join(imgs_dir, f))]
        # Swap only the extension: str.replace("jpg", "csv") would also
        # rewrite "jpg" inside the base name and would miss .JPG / .jpeg.
        anns = [join(ann_dir, os.path.splitext(os.path.basename(im))[0] + ".csv")
                for im in imgs]

        for img, ann in zip(imgs, anns):
            example = ex_info(img, ann)
            # NOTE(review): record writing is disabled below, so the output
            # TFRecord stays empty -- re-enable once the labels are verified.
            #tf_example = create_tf_example(example)
            #writer.write(tf_example.SerializeToString())

        with open(labelDest, 'w', encoding='utf-8') as outL:
            for name, key in LABEL_DICT.items():
                # The name is emitted as a single-quoted text-format string,
                # so backslashes and single quotes in a label must be escaped
                # or the label map cannot be parsed.
                escaped = name.replace("\\", "\\\\").replace("'", "\\'")
                outL.write("item { \n id: " + str(key) + "\n name: '" + escaped + "'\n}\n")
    finally:
        writer.close()


if __name__ == '__main__':
    tf.app.run()
但是当我运行训练脚本(train.py)时,得到了下面的错误:
python train.py --logtostderr --train_dir=./models/train --pipeline_config_path=faster_rcnn_resnet101_coco.config
WARNING:tensorflow:From models/research/object_detection/trainer.py:257: create_global_step (from tensorflow.contrib.framework.python.ops.variables) is deprecated and will be removed in a future version. Instructions for updating: Please switch to tf.train.create_global_step
Traceback (most recent call last):
  File "models/research/object_detection/utils/label_map_util.py", line 135, in load_labelmap
    text_format.Merge(label_map_string, label_map)
  File "/home/user/anaconda3/envs/tf/lib/python3.6/site-packages/google/protobuf/text_format.py", line 525, in Merge
    descriptor_pool=descriptor_pool)
  File "/home/user/anaconda3/envs/tf/lib/python3.6/site-packages/google/protobuf/text_format.py", line 579, in MergeLines
    return parser.MergeLines(lines, message)
  File "/home/user/anaconda3/envs/tf/lib/python3.6/site-packages/google/protobuf/text_format.py", line 612, in MergeLines
    self._ParseOrMerge(lines, message)
  File "/home/user/anaconda3/envs/tf/lib/python3.6/site-packages/google/protobuf/text_format.py", line 627, in _ParseOrMerge
    self._MergeField(tokenizer, message)
  File "/home/user/anaconda3/envs/tf/lib/python3.6/site-packages/google/protobuf/text_format.py", line 787, in _MergeField
    merger(tokenizer, message, field)
  File "/home/user/anaconda3/envs/tf/lib/python3.6/site-packages/google/protobuf/text_format.py", line 815, in _MergeMessageField
    self._MergeField(tokenizer, sub_message)
  File "/home/user/anaconda3/envs/tf/lib/python3.6/site-packages/google/protobuf/text_format.py", line 695, in _MergeField
    (message_descriptor.full_name, name))
google.protobuf.text_format.ParseError: 23:20 : Message type "object_detection.protos.StringIntLabelMapItem" has no field named "S".
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
  File "train.py", line 184, in <module>
    tf.app.run()
  File "/home/user/anaconda3/envs/tf/lib/python3.6/site-packages/tensorflow/python/platform/app.py", line 126, in run
    _sys.exit(main(argv))
  File "train.py", line 180, in main
    graph_hook_fn=graph_rewriter_fn)
  File "models/research/object_detection/trainer.py", line 264, in train
    train_config.prefetch_queue_capacity, data_augmentation_options)
  File "models/research/object_detection/trainer.py", line 59, in create_input_queue
    tensor_dict = create_tensor_dict_fn()
  File "train.py", line 121, in get_next
    dataset_builder.build(config)).get_next()
  File "models/research/object_detection/builders/dataset_builder.py", line 155, in build
    label_map_proto_file=label_map_proto_file)
  File "models/research/object_detection/data_decoders/tf_example_decoder.py", line 245, in __init__
    use_display_name)
  File "models/research/object_detection/utils/label_map_util.py", line 152, in get_label_map_dict
    label_map = load_labelmap(label_map_path)
  File "models/research/object_detection/utils/label_map_util.py", line 137, in load_labelmap
    label_map.ParseFromString(label_map_string)
TypeError: a bytes-like object is required, not 'str'
我不明白问题出在哪里:是在 tfrecord 中?在 labels.pbtxt 中?还是在配置文件中?
答案 0 :(得分:1)
好的,我通过调试 TensorFlow 解决了这个问题。虽然文件是 utf-8 格式,但我的标签仍然无法被 TensorFlow 正确读取,因为其中包含一些特殊字符,例如 &、ù、à。把这些字符从 csv 中删除之后,训练就可以正常开始了。