Sparse dataset for TFRecords

Posted: 2018-12-05 10:25:49

Tags: tensorflow object-detection object-detection-api tfrecord

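I have a sparse dataset in which some images contain no instances of the objects of interest. I want to use this dataset for object detection, so I need to generate TFRecords.

Every image in the dataset has an annotation file, including the images with nothing to detect. For those images, every entry is filled with zeros: instance_id = 0, class_id = 0.0, x = 0.0, y = 0.0, w = 0.0, h = 0.0, and so on.

In my class label map, the label IDs start at 1, because TensorFlow requires this (label ID 0 is reserved by TensorFlow for the background).

Now I don't understand how to create the TFRecords. When I try to use all the annotation files, I get a KeyError for class_id = 0, because that class is not in the label map.

I don't think it is possible to create TFRecords without using an annotation file for a sample, so I need to use these dummy annotations filled with zeros. Any ideas on how to solve this?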
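One possible direction, sketched under the assumption that each annotation row has the layout (instance_id, class_id, x, y, w, h) described above: treat an all-zero row as "this image has no objects" and drop it before any label-map lookup, so the image is written with empty box and class lists instead of a class 0 entry. This assumes the training pipeline accepts images whose per-object lists are empty; the helper names `is_placeholder_row` and `real_object_rows` are hypothetical.

```python
def is_placeholder_row(row):
    """Return True if every field of an annotation row parses as zero.

    Assumes the row layout (instance_id, class_id, x, y, w, h); an all-zero
    row is the dummy annotation used for images without objects.
    """
    return all(float(value) == 0.0 for value in row)


def real_object_rows(rows):
    """Keep only the rows that are not all-zero placeholders."""
    return [row for row in rows if not is_placeholder_row(row)]


rows = [
    ["0", "0.0", "0.0", "0.0", "0.0", "0.0"],  # dummy row for an empty image
    ["1", "2.0", "10", "20", "30", "40"],      # one real object (class '2.0')
]
print(real_object_rows(rows))
# [['1', '2.0', '10', '20', '30', '40']]
```

After filtering, an image whose only row was the placeholder simply ends up with no object rows, and nothing ever looks up class '0.0'.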

label_map.pbtxt

    item {
      id: 1
      name: '1.0'
      display_name: 'table'
    }

    item {
      id: 2
      name: '2.0'
      display_name: 'chair'
    }

    item {
      id: 3
      name: '3.0'
      display_name: 'window'
    }
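The KeyError follows directly from this label map: `label_map_dict` in `prepare_example` is keyed by the `name` strings ('1.0', '2.0', '3.0'), so the placeholder class '0.0' has no entry. The sketch below uses a small hypothetical regex parser, `parse_label_map`, instead of the Object Detection API's own label-map utilities, only to show the lookup with the standard library; it handles just the simple one-field-per-line layout shown here.

```python
import re

LABEL_MAP_TEXT = """
item {
  id: 1
  name: '1.0'
  display_name: 'table'
}
item {
  id: 2
  name: '2.0'
  display_name: 'chair'
}
item {
  id: 3
  name: '3.0'
  display_name: 'window'
}
"""


def parse_label_map(text):
    """Map each item's `name` to its integer `id` (hypothetical minimal parser)."""
    label_map = {}
    # Each item body is the text between "item {" and the next "}".
    for body in re.findall(r"item\s*\{(.*?)\}", text, re.S):
        item_id = int(re.search(r"\bid:\s*(\d+)", body).group(1))
        # \bname does not match inside "display_name" (no word boundary after "_").
        name = re.search(r"\bname:\s*'([^']*)'", body).group(1)
        label_map[name] = item_id
    return label_map


label_map_dict = parse_label_map(LABEL_MAP_TEXT)
print(label_map_dict)           # {'1.0': 1, '2.0': 2, '3.0': 3}
print('0.0' in label_map_dict)  # False, so label_map_dict['0.0'] raises KeyError
```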


    import csv
    import hashlib
    import io

    import numpy as np
    import PIL.Image as pil
    import tensorflow as tf
    from object_detection.utils import dataset_util


    def prepare_example(image_path, annotations, label_map_dict):
      """Converts a dictionary with annotations for an image to tf.Example proto.

      Args:
        image_path: The complete path to image.
        annotations: A dictionary representing the annotation of a single object
          that appears in the image.
        label_map_dict: A map from string label names to integer ids.

      Returns:
        example: The converted tf.Example.
      """
      # Read the raw PNG bytes; decode them only to obtain the image size.
      with tf.gfile.GFile(image_path, 'rb') as fid:
        encoded_png = fid.read()
      encoded_png_io = io.BytesIO(encoded_png)
      image = pil.open(encoded_png_io)
      image = np.asarray(image)

      key = hashlib.sha256(encoded_png).hexdigest()

      width = int(image.shape[1])
      height = int(image.shape[0])

      # Convert pixel boxes (x1, y1, w, h) to normalized corner coordinates.
      xmin_norm = annotations['x1'] / float(width)
      ymin_norm = annotations['y1'] / float(height)
      xmax_norm = (annotations['x1'] + annotations['w']) / float(width)
      ymax_norm = (annotations['y1'] + annotations['h']) / float(height)

      difficult_obj = [0] * len(xmin_norm)

      example = tf.train.Example(features=tf.train.Features(feature={
          'image/height': dataset_util.int64_feature(height),
          'image/width': dataset_util.int64_feature(width),
          'image/filename': dataset_util.bytes_feature(image_path.encode('utf8')),
          'image/source_id': dataset_util.bytes_feature(image_path.encode('utf8')),
          'image/key/sha256': dataset_util.bytes_feature(key.encode('utf8')),
          'image/encoded': dataset_util.bytes_feature(encoded_png),
          'image/format': dataset_util.bytes_feature('png'.encode('utf8')),
          'image/object/bbox/xmin': dataset_util.float_list_feature(xmin_norm),
          'image/object/bbox/xmax': dataset_util.float_list_feature(xmax_norm),
          'image/object/bbox/ymin': dataset_util.float_list_feature(ymin_norm),
          'image/object/bbox/ymax': dataset_util.float_list_feature(ymax_norm),
          # 'class/text' holds the class names as bytes; 'class/label' holds the
          # integer ids. A name missing from label_map_dict (e.g. '0.0') raises
          # KeyError here.
          'image/object/class/text': dataset_util.bytes_list_feature(
              [x.encode('utf8') for x in annotations['class_label']]),
          'image/object/class/label': dataset_util.int64_list_feature(
              [label_map_dict[x] for x in annotations['class_label']]),
          'image/object/difficult': dataset_util.int64_list_feature(difficult_obj),
      }))

      return example
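The box arithmetic in `prepare_example` maps a pixel box (x1, y1, w, h) to normalized corners: xmin = x1 / width, xmax = (x1 + w) / width, and likewise for y with the height. A plain-Python sketch of the same arithmetic is shown below (the name `normalized_corners` is hypothetical). It returns four empty lists when an image has no boxes, which is what the per-object features would contain for an image whose placeholder row has been dropped.

```python
def normalized_corners(boxes, width, height):
    """Convert pixel boxes (x1, y1, w, h) to normalized corner lists.

    Returns four parallel lists (xmin, ymin, xmax, ymax); all of them are
    empty when `boxes` is empty.
    """
    xmin, ymin, xmax, ymax = [], [], [], []
    for x1, y1, w, h in boxes:
        xmin.append(x1 / float(width))
        ymin.append(y1 / float(height))
        xmax.append((x1 + w) / float(width))
        ymax.append((y1 + h) / float(height))
    return xmin, ymin, xmax, ymax


print(normalized_corners([(10, 20, 30, 40)], width=100, height=200))
# ([0.1], [0.1], [0.4], [0.3])
print(normalized_corners([], width=100, height=200))
# ([], [], [], [])
```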

    def read_annotation_file(filename):
      """Reads a CN annotation file.

      Converts a CN annotation file into a dictionary containing all the
      relevant information.

      Args:
        filename: the path to the annotation text file.

      Returns:
        anno: A dictionary with the converted annotation information. See annotation
        README file for details on the different fields.
      """
      with open(filename, newline='') as f:
        reader = csv.reader(f)
        next(reader)  # skip header
        content = next(reader)  # only the first data row (one object) is used

      anno = {}
      anno['inst_id'] = np.array([float(content[0])])
      anno['class_label'] = np.array([content[1]])

      # Box as top-left corner (x1, y1) plus width and height, in pixels.
      anno['x1'] = np.array([float(content[2])])
      anno['y1'] = np.array([float(content[3])])

      anno['w'] = np.array([float(content[4])])
      anno['h'] = np.array([float(content[5])])

      return anno
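`read_annotation_file` keeps only the first data row, so it always yields exactly one object, including the all-zero placeholder for empty images. The variant below reads every row, assumes the same column order (instance_id, class_id, x1, y1, w, h), and drops placeholder rows. It is a hypothetical helper (`read_object_rows`) that takes an open text file and returns plain lists, so they would need `np.asarray` before the array division in `prepare_example`. The header names in the usage example are made up.

```python
import csv
import io


def read_object_rows(csv_file):
    """Return a dict of per-object lists, skipping the header and zero rows.

    `csv_file` is any open text file. The column order is assumed to be
    instance_id, class_id, x1, y1, w, h, as in read_annotation_file.
    """
    anno = {'inst_id': [], 'class_label': [], 'x1': [], 'y1': [], 'w': [], 'h': []}
    reader = csv.reader(csv_file)
    next(reader)  # skip header
    for row in reader:
        # Skip blank lines and the all-zero dummy row used for empty images.
        if not row or all(float(v) == 0.0 for v in row):
            continue
        anno['inst_id'].append(float(row[0]))
        anno['class_label'].append(row[1])
        anno['x1'].append(float(row[2]))
        anno['y1'].append(float(row[3]))
        anno['w'].append(float(row[4]))
        anno['h'].append(float(row[5]))
    return anno


empty_image = io.StringIO("inst_id,class_id,x,y,w,h\n0,0.0,0.0,0.0,0.0,0.0\n")
print(read_object_rows(empty_image)['class_label'])  # []

one_chair = io.StringIO("inst_id,class_id,x,y,w,h\n1,2.0,10,20,30,40\n")
print(read_object_rows(one_chair)['class_label'])    # ['2.0']
```

With this reader, an empty image produces empty lists for every field, so no class '0.0' ever reaches the label-map lookup.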

0 answers