如何在仅将批处理加载到张量流的新数据集API中时避免加载整个数据?

时间:2019-04-30 21:37:41

标签: python tensorflow tensorflow-datasets

在使用内置的Tensorflow优化器训练我的一种机器学习模型时,我只想加载一批数据样本来训练参数。但是,使用新的数据集API,我只知道首先加载整个训练数据集并一次读取单个批次的方法。

以下是我当前的代码:

X = np.random.sample((10000,2))
dataset = tf.data.Dataset.from_tensor_slices(X).shuffle(buffer_size=100).repeat().batch(10)

iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()

training_iterations = 100

with tf.Session() as sess:

      for iter in range(training_iterations):

             traninig_batch = sess.run(next_batch)
             # Do something with the training_batch here

因此,我想知道是否有可能使用数据集API快速获取X,而不是加载整个training_batch。尽管对于X的这种简单分布来说,这可能是微不足道的,但是在我的问题中,X来自某种复杂的概率分布,为此,我需要运行一个函数来从中获取样本。

谢谢!

0 个答案:

没有答案