Tensorflow 2.0数据集和数据加载器

时间:2019-10-22 13:52:14

标签: tensorflow pytorch tensorflow2.0

我是pytorch用户,我习惯于pytorch中的data.dataset和data.dataloader API。我正在尝试使用tensorflow 2.0构建相同的模型,并且我想知道pytorch中是否有一个与这些API相似的API。

如果没有这样的api,你们中的任何人都可以告诉我人们通常如何在tensorflow中实现数据加载部分吗?我使用过tensorflow 1,但从未有过使用数据集api的经验。我之前已经硬编码过。我希望有类似覆盖 getitem 的操作,仅将索引作为输入。

非常感谢。

2 个答案:

答案 0 :(得分:2)

在使用tf.data API时,通常也将使用map函数。

在PyTorch中,您的__getItem__调用基本上是从__init__中给出的数据结构中获取元素,并在必要时进行转换。

在TF2.0中,您可以通过使用Dataset函数之一初始化Dataset.from_...来执行相同的操作(请参阅from_generatorfrom_tensor_slicesfrom_tensors) ;这实际上是PyTorch __init__的{​​{1}}部分。然后,您可以调用Dataset来进行在map中进行的按元素操作。

Tensorflow数据集几乎是花哨的迭代器,因此从设计上来说,您不必使用索引来访问它们的元素,而是遍历它们。

__getItem__上的guide非常有用,并提供了各种各样的示例。

答案 1 :(得分:1)

我对Pytorch并不熟悉,但是Tensorflow实现了Keras API,该API的Sequence类为:

  

适合于一系列数据(例如数据集)的基础对象

https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence

此类包含索引的getitem。