从张量流记录数据集中提取图像数据

时间:2020-07-20 19:03:24

标签: python tensorflow machine-learning keras computer-vision

我最近开始研究使用Tensorflow的CNN,发现tfrecords对加快训练速度非常有帮助,但是我在使用数据API时遇到了麻烦。
解析之后,我的数据集由(图像,标签)元组组成,这很适合训练,但是我试图将图像提取到另一个数据集中以调用keras.predict()。

我已经尝试过此解决方案:

test_set = get_set_tfrecord(test_path, _parse_function, num_parallel_calls = 4)

lab = []
f = True
for image, label in test_set.take(600):
    if f:
      img = tf.data.Dataset.from_tensors(image)
      f = False
    else:
      img = img.concatenate(tf.data.Dataset.from_tensors(image))
    lab.append(label.numpy())

天真,不是很好的代码,但是它可以执行级联(即堆叠),但它可以将每个图像加载到RAM中。

这样做的正确方法是什么?

1 个答案:

答案 0 :(得分:1)

您可以使用map中的tf.data.Dataset API。您可以编写以下代码。

result = test_set.map(lambda image, label: image)
# You can iterate and check what you have received at the end.
# I expect only the images.
for image in result.take(1):
    print(image)

我希望使用上面的代码可以解决问题,并且此答案对您有帮助。