如何将地图功能应用于tf.Tensor

时间:2020-05-29 05:21:16

标签: python tensorflow tensorflow2.0 tensor tensorflow-datasets

dataset = tf.data.Dataset.from_tensor_slices((images,boxes))
function_to_map = lambda x,y: func3(x,y)
fast_benchmark(dataset.map(function_to_map).batch(1).prefetch(tf.data.experimental.AUTOTUNE))

现在我是func3

def fast_benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    print('dataset->',dataset)
    for _ in tf.data.Dataset.range(num_epochs):
        for _,__ in dataset:
            print(_,__)
            break
            pass

印刷品的输出是

tf.Tensor([b'/media/jake/mark-4tb3/input/datasets/pascal/VOCtrainval_11-May-2012/VOCdevkit/VOC2012/JPEGImages/2008_000008.jpg'], shape=(1,), dtype=string) <tf.RaggedTensor [[[52, 86, 470, 419], [157, 43, 288, 166]]]>

我想在func3()中做什么
想要将图像目录更改为真实图像并运行批处理

1 个答案:

答案 0 :(得分:1)

您需要从张量中提取字符串并使用适当的图像读取功能。以下是代码中要实现此目的的步骤。

  1. 您必须使用tf.py_function(get_path, [x], [tf.float32])装饰地图功能。您可以找到有关tf.py_function here的更多信息。在tf.py_function中,第一个参数是map函数的名称,第二个参数是要传递给map函数的元素,最后一个参数是返回类型。
  2. 您可以通过在地图函数中使用bytes.decode(file_path.numpy())来获取字符串部分。
  3. 使用适当的功能加载图像。我们正在使用load_img

在下面的简单程序中,我们使用tf.data.Dataset.list_files来读取图像的路径。接下来,在map函数中,我们将使用load_img读取图像,然后执行tf.image.central_crop函数来裁剪图像的中心部分。

代码-

%tensorflow_version 2.x
import tensorflow as tf
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array, array_to_img
from matplotlib import pyplot as plt
import numpy as np

def load_file_and_process(path):
    image = load_img(bytes.decode(path.numpy()), target_size=(224, 224))
    image = img_to_array(image)
    image = tf.image.central_crop(image, np.random.uniform(0.50, 1.00))
    return image

train_dataset = tf.data.Dataset.list_files('/content/bird.jpg')
train_dataset = train_dataset.map(lambda x: tf.py_function(load_file_and_process, [x], [tf.float32]))

for f in train_dataset:
  for l in f:
    image = np.array(array_to_img(l))
    plt.imshow(image)

输出-

enter image description here

希望这能回答您的问题。学习愉快。