我最近开始研究使用Tensorflow的CNN,发现tfrecords对加快训练速度非常有帮助,但是我在使用数据API时遇到了麻烦。
解析之后,我的数据集由(图像,标签)元组组成,这很适合训练,但是我试图将图像提取到另一个数据集中以调用keras.predict()。
我已经尝试过此解决方案:
test_set = get_set_tfrecord(test_path, _parse_function, num_parallel_calls = 4)
lab = []
f = True
for image, label in test_set.take(600):
if f:
img = tf.data.Dataset.from_tensors(image)
f = False
else:
img = img.concatenate(tf.data.Dataset.from_tensors(image))
lab.append(label.numpy())
天真,不是很好的代码,但是它可以执行级联(即堆叠),但它可以将每个图像加载到RAM中。
这样做的正确方法是什么?
答案 0 :(得分:1)
您可以使用map
中的tf.data.Dataset
API。您可以编写以下代码。
result = test_set.map(lambda image, label: image)
# You can iterate and check what you have received at the end.
# I expect only the images.
for image in result.take(1):
print(image)
我希望使用上面的代码可以解决问题,并且此答案对您有帮助。