我想使用TensorFlow 2数据集对象将图像馈送到CNN。我的图像位于AWS S3上,但在示例中,我将使用来自Wikipedia的图像(问题是相同的)。
image_urls = [
'https://upload.wikimedia.org/wikipedia/commons/6/60/Matterhorn_from_Domh%C3%BCtte_-_2.jpg',
'https://upload.wikimedia.org/wikipedia/commons/6/6e/Matterhorn_from_Klein_Matterhorn.jpg',
]
dataset = tf.data.Dataset.from_tensor_slices(image_urls)
def read_image_from_url(url):
img_array = None
with urlopen(url) as request:
img_array = np.asarray(bytearray(request.read()), dtype=np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB) #as RGB image (cv2 is BGR by default)
当我使用数据集的一个元素测试函数时,它会起作用:
url = next(iter(dataset)).numpy().decode('utf-8')
img = read_image_from_url(url)
plt.imshow(img)
但是当我将函数映射到数据集以创建用于图像的新数据集时,它会失败:
dataset_images = dataset.map(lambda x: read_image_from_url(x.numpy().decode('utf-8')))
AttributeError: in converted code:
<ipython-input-6-e8eb89833196>:2 None *
map_func=lambda x: read_image_from_url(x.numpy().decode('utf-8')),
AttributeError: 'Tensor' object has no attribute 'numpy'
很明显,当用next
或map
进行迭代时,数据集提供了不同的dtype。知道我该如何解决吗?
答案 0 :(得分:3)
这比所需的难度要大:
import tensorflow as tf
import numpy as np
import cv2
from urllib.request import urlopen
import matplotlib.pyplot as plt
image_urls = [
'https://upload.wikimedia.org/wikipedia/commons/6/60/Matterhorn_from_Domh%C3%BCtte_-_2.jpg',
'https://upload.wikimedia.org/wikipedia/commons/6/6e/Matterhorn_from_Klein_Matterhorn.jpg',
]
dataset = tf.data.Dataset.from_tensor_slices(image_urls)
def get(url):
with urlopen(str(url.numpy().decode("utf-8"))) as request:
img_array = np.asarray(bytearray(request.read()), dtype=np.uint8)
img = cv2.imdecode(img_array, cv2.IMREAD_COLOR)
return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
def read_image_from_url(url):
return tf.py_function(get, [url], tf.uint8)
dataset_images = dataset.map(lambda x: read_image_from_url(x))
for d in dataset_images:
print(d)
为什么第一个工作成功,然后在tf.Dataset
中失败了? tf.Dataset
是在graph mode
中定义的,而不像第一个那样在eager mode
中定义的。图形模式更快,并且tf.Dataset
已针对速度进行了优化,因此很有意义。在图形模式下,您无法执行.numpy()
,因为所有操作都应在tensorflow
操作中定义。 py_func
将Python函数包装在tf.Operation
中执行的eager mode
中,这正是我们所需要的。
注意:我尝试过tf.keras.utils.get_file()
,但是遇到了您在此描述的类似问题。希望这会有所帮助!