As per the title - I can train with the Object Detection API, but when I look at the loss curves they're suspiciously smooth. At the evaluation stage after one epoch, the ground-truth images show up in TensorBoard, but no boxes are drawn on them. There are no negative examples in my dataset, so every image should have labels.
I've tested with the Oxford Pet dataset and it works fine (the stats look sensible, and the images show up in TensorBoard with boxes). I've compared my TFRecords against the ones generated by the pet script provided with the models repo, and nothing stands out.
I convert a set of images and bounding boxes to TFRecords using the function below. It takes labels in Darknet/YOLO format (box centre x/y, box width, box height, all in normalised units). All the images are 1-channel PNG files (640x512), so I load them and convert them to 3 channels. I wasn't sure whether decode_png would do this conversion automatically when you ask for three channels, and I didn't want to risk it, so I do the conversion in OpenCV first.
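For what it's worth, grayscale-to-colour conversion is just the single plane repeated across three channels, which is what both cv2.IMREAD_COLOR and decode_png with channels=3 produce for a grayscale input. A tiny NumPy sketch of the equivalent operation (the 2x2 array stands in for one of the 640x512 PNGs):

```python
import numpy as np

# A fake 1-channel image, standing in for a grayscale PNG.
gray = np.array([[10, 20], [30, 40]], dtype=np.uint8)

# Grayscale -> 3 channels: the same plane repeated three times.
rgb = np.repeat(gray[:, :, np.newaxis], 3, axis=2)

print(rgb.shape)  # (2, 2, 3)
print(rgb[0, 0])  # [10 10 10]
```

Either route should therefore give identical pixel data; the OpenCV round-trip is just the belt-and-braces option.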
import os

import cv2
import numpy as np
import tensorflow as tf

# Feature helpers from the TF models repo (object_detection/utils/dataset_util.py)
from object_detection.utils.dataset_util import (
    bytes_feature, bytes_list_feature, float_list_feature,
    int64_feature, int64_list_feature)


def create_tf_example(path, names):
    """Creates a tf.Example proto from a sample image.

    Returns:
        example: The created tf.Example, or None on failure.
    """
    annotations = load_annotation(path)
    if annotations is None or len(annotations) == 0:
        return None

    try:
        with tf.gfile.GFile(path, 'rb') as fid:
            image_data = fid.read()
        # Force conversion to 3 channels just to be sure.
        image_cv = cv2.imdecode(np.frombuffer(image_data, np.uint8), cv2.IMREAD_COLOR)
        res, image_data = cv2.imencode('.png', image_cv)
        image_data = image_data.tobytes()
        # Decode once as a sanity check that the re-encoded bytes are valid PNG.
        image_tensor = tf.image.decode_png(image_data, channels=3, name=None)
    except Exception:
        print("Failed: ", path)
        return None

    classes_text = []
    classes = []
    xmins = []
    xmaxs = []
    ymins = []
    ymaxs = []
    height = 512
    width = 640

    for a in annotations:
        class_id, box_cx, box_cy, box_width, box_height = a
        class_id = int(class_id)
        if class_id < len(names):
            # Convert Darknet centre/size to corner coordinates, clamped to [0, 1].
            xmin = max(0, float(box_cx - 0.5 * box_width))
            assert 0 <= xmin <= 1
            xmax = min(1, float(box_cx + 0.5 * box_width))
            assert 0 <= xmax <= 1
            ymin = max(0, float(box_cy - 0.5 * box_height))
            assert 0 <= ymin <= 1
            ymax = min(1, float(box_cy + 0.5 * box_height))
            assert 0 <= ymax <= 1
            xmins.append(xmin)
            xmaxs.append(xmax)
            ymins.append(ymin)
            ymaxs.append(ymax)
            classes.append(class_id + 1)
            classes_text.append(names[class_id].encode('utf8'))

    # Possible we've found annotations with invalid class IDs only.
    if len(xmins) == 0:
        print("Class out of range")
        return None

    tf_example = tf.train.Example(features=tf.train.Features(feature={
        'image/height': int64_feature(height),
        'image/width': int64_feature(width),
        'image/filename': bytes_feature(os.path.basename(path).encode('utf8')),
        'image/source_id': bytes_feature(os.path.basename(path).encode('utf8')),
        'image/encoded': bytes_feature(image_data),
        'image/format': bytes_feature('png'.encode('utf8')),
        'image/object/bbox/xmin': float_list_feature(xmins),
        'image/object/bbox/xmax': float_list_feature(xmaxs),
        'image/object/bbox/ymin': float_list_feature(ymins),
        'image/object/bbox/ymax': float_list_feature(ymaxs),
        'image/object/class/text': bytes_list_feature(classes_text),
        'image/object/class/label': int64_list_feature(classes),
    }))
    return tf_example
The Darknet annotation files look like this (hence the +1 on class_id):
0 0.251252 0.35801225 0.36322 0.25812092
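load_annotation isn't shown in the question; a minimal sketch of what a parser for these lines might look like (parse_darknet_line is a hypothetical name - the centre-to-corner arithmetic mirrors create_tf_example above):

```python
def parse_darknet_line(line):
    """Parse one Darknet/YOLO annotation line into (class_id, cx, cy, w, h)."""
    fields = line.split()
    return (int(fields[0]),) + tuple(float(f) for f in fields[1:5])

class_id, cx, cy, w, h = parse_darknet_line("0 0.251252 0.35801225 0.36322 0.25812092")

# Centre/size -> corner coordinates, as done when building the tf.Example.
xmin, xmax = cx - 0.5 * w, cx + 0.5 * w
ymin, ymax = cy - 0.5 * h, cy + 0.5 * h
print(class_id, xmin, xmax, ymin, ymax)
```

For the sample line above this gives corners of roughly (0.0696, 0.4329) in x and (0.2290, 0.4871) in y, all safely inside [0, 1].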
Here's a quick test:
raw_image_dataset = tf.data.TFRecordDataset('/home/josh/data/data/test.record-00000-of-00010')

# Create a dictionary describing the features.
image_feature_description = {
    'image/height': tf.FixedLenFeature([], tf.int64),
    'image/width': tf.FixedLenFeature([], tf.int64),
    'image/encoded': tf.FixedLenFeature([], tf.string),
    'image/object/bbox/xmax': tf.VarLenFeature(tf.float32),
    'image/object/bbox/xmin': tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymin': tf.VarLenFeature(tf.float32),
    'image/object/bbox/ymax': tf.VarLenFeature(tf.float32),
    'image/object/class/text': tf.VarLenFeature(tf.string),
    'image/object/class/label': tf.VarLenFeature(tf.int64),
}

def _parse_image_function(example_proto):
    # Parse the input tf.Example proto using the dictionary above.
    return tf.parse_single_example(example_proto, image_feature_description)

parsed_image_dataset = raw_image_dataset.map(_parse_image_function)

for im in parsed_image_dataset:
    print(im['image/object/bbox/xmin'])
    print(im['image/object/bbox/xmax'])
    print(im['image/object/bbox/ymin'])
    print(im['image/object/bbox/ymax'])
    print(im['image/object/class/label'])
    break
SparseTensor(indices=tf.Tensor(
[[0]
[1]], shape=(2, 1), dtype=int64), values=tf.Tensor([0.390625 0.4687505], shape=(2,), dtype=float32), dense_shape=tf.Tensor([2], shape=(1,), dtype=int64))
SparseTensor(indices=tf.Tensor(
[[0]
[1]], shape=(2, 1), dtype=int64), values=tf.Tensor([0.446875 0.5093755], shape=(2,), dtype=float32), dense_shape=tf.Tensor([2], shape=(1,), dtype=int64))
SparseTensor(indices=tf.Tensor(
[[0]
[1]], shape=(2, 1), dtype=int64), values=tf.Tensor([0.3923828 0.4685552], shape=(2,), dtype=float32), dense_shape=tf.Tensor([2], shape=(1,), dtype=int64))
SparseTensor(indices=tf.Tensor(
[[0]
[1]], shape=(2, 1), dtype=int64), values=tf.Tensor([0.4451172 0.5095708], shape=(2,), dtype=float32), dense_shape=tf.Tensor([2], shape=(1,), dtype=int64))
SparseTensor(indices=tf.Tensor(
[[0]
[1]], shape=(2, 1), dtype=int64), values=tf.Tensor([1 1], shape=(2,), dtype=int64), dense_shape=tf.Tensor([2], shape=(1,), dtype=int64))
I've tried turning on debug logging, but it only shows iterations/loss.
Answer (score: 0)
Figured it out: my class labels each had a newline on the end in every TFRecord. Finding this took far more effort than I was willing to put in!
I have a .names file which I read in, e.g.:
cat
dog
horse
Except it isn't - it's actually:
cat\n
dog\n
horse\n
and when I opened the names file with f.readlines() to get the class labels, I was inadvertently storing the newlines in the records. Now, TensorFlow shouldn't actually use that text label at all: it has the label map file and uses the class IDs in the model. My suspicion is that whatever parses the TFRecord hits the newline in the class/text field and then ignores the label completely.
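The failure mode is easy to reproduce: f.readlines() keeps the trailing newline on every line, so the encoded class/text bytes silently differ from the label map entries. A small sketch (using io.StringIO in place of the real .names file):

```python
import io

names_file = io.StringIO("cat\ndog\nhorse\n")

# readlines() keeps the trailing '\n' -- this is the bug.
raw = names_file.readlines()
print(raw)    # ['cat\n', 'dog\n', 'horse\n']

# Stripping each line (or using splitlines()) gives the intended labels.
names = [n.rstrip() for n in raw]
print(names)  # ['cat', 'dog', 'horse']

# The bytes actually stored in the TFRecord differ accordingly.
assert raw[0].encode('utf8') == b'cat\n'
assert names[0].encode('utf8') == b'cat'
```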
I'd strongly suggest anyone doing this uses something like:
classes_text.append(names[class_id].rstrip().encode('utf8'))