tf.data.Dataset是否支持生成字典结构?

时间:2018-01-17 23:10:00

标签: tensorflow tensorflow-datasets tensorflow-estimator

以下是[https://www.tensorflow.org/programmers_guide/datasets]的一段代码。在此示例中,map函数是用于读取数据的用户定义函数。在map函数中,我们需要设置输出类型为[tf.uint8, label.dtype]

import cv2

# Use a custom OpenCV function to read the image, instead of the standard
# TensorFlow `tf.read_file()` operation.
def _read_py_function(filename, label):
  image_decoded = cv2.imread(image_string, cv2.IMREAD_GRAYSCALE)
  return image_decoded, label

# Use standard TensorFlow operations to resize the image to a fixed shape.
def _resize_function(image_decoded, label):
  image_decoded.set_shape([None, None, None])
  image_resized = tf.image.resize_images(image_decoded, [28, 28])
  return image_resized, label

  filenames = ["/var/data/image1.jpg", "/var/data/image2.jpg", ...]
  labels = [0, 37, 29, 1, ...]

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(
  lambda filename, label: tuple(tf.py_func(
    _read_py_function, [filename, label], [tf.uint8, label.dtype])))
dataset = dataset.map(_resize_function)

我的问题是,如果我们想要_read_py_function()输出一个Python字典,那么我们如何设置outptu类型?是否有继承数据类型,例如tf.dict?例如:

def _read_py_function(filename):
  image_filename = filename[0]
  label_filename = filename[1]
  image_id = filename[2]
  image_age = filename[3]
  image_decoded = cv2.imread(image_filename, cv2.IMREAD_GRAYSCALE)
  image_decoded = cv2.imread(label_fielname, cv2.IMREAD_GRAYSCALE)
  return {'image':image_decoded, 'label':label_decoded, 'id':image_id, 'age':image_age}

然后,我们如何设计dataset.map()函数?

2 个答案:

答案 0 :(得分:3)

tf.data.Dataset.map调用的函数内返回dicts应该按预期工作。

以下是一个例子:

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: {'a': x, 'b': 2 * x})
dataset = dataset.map(lambda y: y['a'] + y['b'])

res = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for i in range(10):
        assert sess.run(res) == 3 * i

答案 1 :(得分:0)

要添加到上述答案中,这也可以:

dataset = tf.data.Dataset.range(10)
dataset = dataset.map(lambda x: {'a': x, 'b': 2 * x})

res = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for i in range(10):
        curr_res = sess.run(res)
        assert curr_res['a'] == i
        assert curr_res['b'] == 2 * i