将列表项映射到tensorflow数据集字典

时间:2020-04-02 07:12:56

标签: python tensorflow dictionary mapping dictionary-comprehension

我正在尝试将图像信息映射到由图像和标签字典组成的数据集。

parse_function()应该仅从2个文件名路径和标签列表中解码。

 def parse_function(filename, label):
    image_string = tf.io.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string)
    image_resized = tf.image.resize(image_decoded, [4, 4])

    return image_resized, label

def dataset_maker(list_sample_paths, list_labels):

    filenames = tf.constant(list_sample_paths)
    labels = tf.constant(list_labels)

    dataset = tf.data.Dataset.from_tensor_slices({"image": filenames, "label": labels})
    dataset = dataset.map(parse_function)

training_dataset = dataset_maker(list_training_sample_paths, list_training_sample_labels)

但是我收到此错误消息

TypeError: tf__parse_function() missing 1 required positional argument: 'label'

在这种情况下,我需要使用dict理解吗? 非常感谢您对解决此问题的任何帮助。 谢谢!

在Srihari Humbarwadi答复后用元组解决此问题时添加此信息: 我想获得字典结构,因为我向Mnist投放了模型。

一个随机的Mnist示例具有以下结构:

{'image': <tf.Tensor: id=140275, shape=(28, 28, 1), dtype=uint8, numpy=array([[[  0],[  0],[  0]],dtype=uint8)>, 'label': <tf.Tensor: id=140276, shape=(), dtype=int64, numpy=6>}

2 个答案:

答案 0 :(得分:1)

您不需要以字典的形式传递文件名和标签的列表。您可以通过传递一个元组即使其工作。 (filenames, labels)。这是我使用的完整代码:

from glob import glob
import numpy as np
import tensorflow as tf

print('TensorFlow:', tf.__version__)

list_training_sample_paths = sorted(glob('images/*'))
# random integer labels
list_training_sample_labels = np.random.randint(low=0, high=5, size=[len(list_training_sample_paths)])

def parse_function(filename, label):
    image_string = tf.io.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string)
    image_resized = tf.image.resize(image_decoded, [4, 4])

    return image_resized, label

def dataset_maker(list_sample_paths, list_labels):

    filenames = tf.constant(list_sample_paths)
    labels = tf.constant(list_labels)

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(parse_function)
    return dataset

training_dataset = dataset_maker(list_training_sample_paths, list_training_sample_labels)
tf.data.experimental.get_structure(training_dataset)

输出

TensorFlow: 2.2.0-rc2
(TensorSpec(shape=(4, 4, None), dtype=tf.float32, name=None), TensorSpec(shape=(),dtype=tf.int64, name=None))

答案 1 :(得分:0)

传递给映射函数的参数应与数据集对象的元素类型匹配。因此,如果您的元素是dict,则可以在dict中传递parse_function并返回dict
例如:

def parse_function(data):

    data_out = data
    filename = data['image']
    image_string = tf.io.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string)
    image_resized = tf.image.resize(image_decoded, [4, 4])
    data_out['image'] = image_resized

    return data_out

def dataset_maker(list_sample_paths, list_labels):

    filenames = tf.constant(list_sample_paths)
    labels = tf.constant(list_labels)

    dataset = tf.data.Dataset.from_tensor_slices({"image": filenames, "label": labels})
    dataset = dataset.map(parse_function)
    return dataset

training_dataset = dataset_maker(list_training_sample_paths, list_training_sample_labels)

或者,您可以传递@Srihari Humbarwadi建议的元组并返回字典。像这样:

def parse_function(filename, label):

    image_string = tf.io.read_file(filename)
    image_decoded = tf.image.decode_jpeg(image_string)
    image_resized = tf.image.resize(image_decoded, [4, 4])
    data_out = {'image': image_resized, 'label': label}

    return data_out

def dataset_maker(list_sample_paths, list_labels):

    filenames = tf.constant(list_sample_paths)
    labels = tf.constant(list_labels)

    dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
    dataset = dataset.map(parse_function)
    return dataset