I am looking for the right way to implement next_batch in TensorFlow. My training data is train_X = 10000x50, where 10000 is the number of samples and 50 is the dimensionality of each feature vector, and train_Y = 10000x1. I use a batch size of 128. This is the function I use to fetch a training batch during training:
import numpy as np
import tensorflow as tf

def next_batch(num, data, labels):
    '''
    Return a total of `num` random samples and labels.
    '''
    idx = np.arange(0, data.shape[0])
    np.random.shuffle(idx)
    idx = idx[:num]
    data_shuffle = [data[i, :] for i in idx]
    labels_shuffle = [labels[i] for i in idx]
    return np.asarray(data_shuffle), np.asarray(labels_shuffle)
n_samples = 10000
batch_size = 128

with tf.Session() as sess:
    sess.run(init)
    n_batches = int(n_samples / batch_size)
    for i in range(n_epochs):
        for j in range(n_batches):
            X_batch, Y_batch = next_batch(batch_size, train_X, train_Y)
With the function above, I find that the shuffle is performed on every batch call, which is not the behaviour I want. We should first go through all the shuffled elements of the training data, and only then shuffle again for the next epoch. Am I right? How can I fix this in TensorFlow? Thanks.
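A quick NumPy sketch of the index-drawing logic (same n_samples and batch_size as above, outside of any TensorFlow session) shows why this matters: because the indices are reshuffled on every call, some samples are drawn several times per epoch while many others are never drawn at all:

import numpy as np

n_samples, batch_size = 10000, 128
n_batches = n_samples // batch_size  # 78

seen = set()
for _ in range(n_batches):
    idx = np.arange(n_samples)
    np.random.shuffle(idx)          # a fresh shuffle for every batch, as in next_batch above
    seen.update(idx[:batch_size])

print(len(seen))  # typically around 6300 distinct samples, not 78 * 128 = 9984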
Answer 0 (score: 1)
The solution is to use a generator to produce the batches, so that the sampling state (the list of shuffled indices and the current position within that list) is kept between calls.
Below is a solution you can build upon.
def next_batch(num, data, labels):
    '''
    Return a total of maximum `num` random samples and labels.
    NOTE: The last batch will be of size len(data) % num
    '''
    num_el = data.shape[0]
    while True:  # or whatever condition you may have
        idx = np.arange(0, num_el)
        np.random.shuffle(idx)
        current_idx = 0
        while current_idx < num_el:
            batch_idx = idx[current_idx:current_idx + num]
            current_idx += num
            data_shuffle = [data[i, :] for i in batch_idx]
            labels_shuffle = [labels[i] for i in batch_idx]
            yield np.asarray(data_shuffle), np.asarray(labels_shuffle)
n_samples = 10000
batch_size = 128

with tf.Session() as sess:
    sess.run(init)
    n_batches = int(n_samples / batch_size)
    next_batch_gen = next_batch(batch_size, train_X, train_Y)
    for i in range(n_epochs):
        for j in range(n_batches):
            X_batch, Y_batch = next(next_batch_gen)
            print(Y_batch)
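As a side note, if the TensorFlow version in use is 1.4 or later, the same shuffle-once-per-epoch behaviour can also be expressed with the tf.data input pipeline instead of a hand-written generator. A minimal sketch, assuming train_X and train_Y are the NumPy arrays described in the question (the n_epochs value here is just an example):

import tensorflow as tf

n_epochs = 10  # assumed; use however many epochs you train for

dataset = tf.data.Dataset.from_tensor_slices((train_X, train_Y))
dataset = dataset.shuffle(buffer_size=10000)  # buffer covers the whole set -> full reshuffle
dataset = dataset.batch(128)
dataset = dataset.repeat(n_epochs)

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    while True:
        try:
            X_batch, Y_batch = sess.run(next_element)
            # ... run your training op on X_batch, Y_batch here ...
        except tf.errors.OutOfRangeError:
            break

With shuffle applied before repeat, every pass over the data is reshuffled independently, which is exactly the per-epoch behaviour asked about in the question.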