在使用内置的Tensorflow优化器训练我的一种机器学习模型时,我只想加载一批数据样本来训练参数。但是,使用新的数据集API,我只知道首先加载整个训练数据集并一次读取单个批次的方法。
以下是我当前的代码:
X = np.random.sample((10000,2))
dataset = tf.data.Dataset.from_tensor_slices(X).shuffle(buffer_size=100).repeat().batch(10)
iterator = dataset.make_one_shot_iterator()
next_batch = iterator.get_next()
training_iterations = 100
with tf.Session() as sess:
for iter in range(training_iterations):
traninig_batch = sess.run(next_batch)
# Do something with the training_batch here
因此,我想知道是否有可能使用数据集API快速获取X
,而不是加载整个training_batch
。尽管对于X
的这种简单分布来说,这可能是微不足道的,但是在我的问题中,X
来自某种复杂的概率分布,为此,我需要运行一个函数来从中获取样本。
谢谢!