我是pytorch用户,我习惯于pytorch中的data.dataset和data.dataloader API。我正在尝试使用tensorflow 2.0构建相同的模型,并且我想知道pytorch中是否有一个与这些API相似的API。
如果没有这样的api,你们中的任何人都可以告诉我人们通常如何在tensorflow中实现数据加载部分吗?我使用过tensorflow 1,但从未有过使用数据集api的经验。我之前已经硬编码过。我希望有类似覆盖 getitem 的操作,仅将索引作为输入。
非常感谢。
答案 0 :(得分:2)
在使用tf.data
API时,通常也将使用map
函数。
在PyTorch中,您的__getItem__
调用基本上是从__init__
中给出的数据结构中获取元素,并在必要时进行转换。
在TF2.0中,您可以通过使用Dataset
函数之一初始化Dataset.from_...
来执行相同的操作(请参阅from_generator
,from_tensor_slices
,from_tensors
) ;这实际上是PyTorch __init__
的{{1}}部分。然后,您可以调用Dataset
来进行在map
中进行的按元素操作。
Tensorflow数据集几乎是花哨的迭代器,因此从设计上来说,您不必使用索引来访问它们的元素,而是遍历它们。
__getItem__
上的guide非常有用,并提供了各种各样的示例。
答案 1 :(得分:1)
我对Pytorch并不熟悉,但是Tensorflow实现了Keras API,该API的Sequence类为:
适合于一系列数据(例如数据集)的基础对象
https://www.tensorflow.org/api_docs/python/tf/keras/utils/Sequence
此类包含索引的getitem。