我正在研究tutorial on GANs by Google。在此笔记本中,他们定义了 input_fn ,其中使用tfds加载了MNIST数据集。我已经生成了自己的数据集,并将其存储在numpy数组中(形状:4500、512、512)。
我无法理解 input_fn 的工作方式以及修改方式,以便可以从gdrive输入训练数据,而不是从tf数据集下载。我注意到,在调用 gan_estimator.train 时,在训练时也会使用 input_fn 。谁能解释该功能的工作原理?
答案 0 :(得分:0)
函数 input_fn 使用下面一行中的TensorFlow数据集来加载MNIST。
tfds.load('mnist', split=split)
.map(_preprocess)
.cache()
.repeat()
您需要了解TensorFlow数据集如何按照需要处理的结构来制作自己的数据集。
您可以获得更多信息here。