从tf.data.Dataset.map()返回数据集会导致'TensorSliceDataset'对象没有属性'get_shape'错误

时间:2018-06-12 04:28:42

标签: tensorflow machine-learning deep-learning tensorflow-datasets

我正在使用数据集API来创建输入管道。我在类似于以下的模式中使用tf.data.Dataset.map()方法:

def mapped_fn(_):
    X = tf.random_uniform([3,3])
    y = tf.random_uniform([3,1])
    dataset = tf.data.Dataset.from_tensor_slices((X,y))
    return dataset

with tf.Session() as sess:
    first = tf.random_uniform([1,2])         
    unimportant_dataset = tf.data.Dataset.from_tensors(first)
    dataset = unimportant_dataset.map(mapped_fn)
    sess.run(dataset)

我收到以下错误:AttributeError: 'TensorSliceDataset' object has no attribute 'get_shape'

整体上下文是mapped_fn从.tfrecords文件反序列化一个示例protobuf(在本例中由unimportant_dataset表示),重新整形特征向量(X),并且需要返回一个数据集,该数据集包含由新特征向量(在这种情况下为形状(3,))中的切片定义的元素。返回ZipDataset时,我遇到了类似的错误。提前谢谢!

2 个答案:

答案 0 :(得分:3)

DomJack's answerDataset.map()的签名绝对正确:它希望传递的mapped_fn的返回值为一个或多个张量(或稀疏张量)。

如果你有一个返回Dataset的函数,你可以使用Dataset.flat_map()将所有返回的数据集展平并连接成一个数据集,如下所示:

def mapped_fn(_):
    X = tf.random_uniform([3,3])
    y = tf.random_uniform([3,1])
    dataset = tf.data.Dataset.from_tensor_slices((X,y))
    return dataset

# Generate 100 dummy elements.
unimportant_dataset = tf.data.Dataset.range(100)

# Convert each dummy element into a dataset of 3 nested elements, and concatenate them.
dataset = unimportant_dataset.flat_map(mapped_fn)

答案 1 :(得分:1)

传递给map_fn的{​​{1}}应该从调用数据集中获取单个示例的张量,并返回返回数据集的张量。

e.g。

tf.data.Dataset.map