如何在tensorflow中自定义数据集中正确实现next_batch?

时间:2018-04-16 13:31:45

标签: tensorflow machine-learning deep-learning

我正在寻找在tensorflow中实现next_batch的正确方法。我的训练数据为train_X=10000x50,其中10000是样本数,50是特征向量的数量,train_Y=10000x1。我使用的批量大小为128.这是我在培训期间获得培训批次的功能

def next_batch(num, data, labels):
    '''
    Return a total of `num` random samples and labels.
    '''
    idx = np.arange(0 , data.shape[0])
    np.random.shuffle(idx)
    idx = idx[:num]
    data_shuffle = [data[ i,:] for i in idx]
    labels_shuffle = [labels[ i] for i in idx]
    return np.asarray(data_shuffle), np.asarray(labels_shuffle)

n_samples = 10000
batch_size =128

with tf.Session() as sess:
sess.run(init)
n_batches = int(n_samples / batch_size)
for i in range(n_epochs):
    for j in range(n_batches):
        X_batch, Y_batch = next_batch(batch_size,train_X,train_Y)

通过上述功能,我发现每个批次都会调用shuffle函数,这不是想要的行为。我们必须扫描训练数据中的所有混洗元素,然后再次为下一个新纪元进行混洗。我对吗?如何在tensorflow中修复它?感谢

1 个答案:

答案 0 :(得分:1)

解决方案是使用生成器生成批次,以便跟踪采样状态(混洗索引列表和此列表中的当前位置)。

在下面找到您可以构建的解决方案。

def next_batch(num, data, labels):
    '''
    Return a total of maximum `num` random samples and labels.
    NOTE: The last batch will be of size len(data) % num
    '''
    num_el = data.shape[0]
    while True: # or whatever condition you may have
        idx = np.arange(0 , num_el)
        np.random.shuffle(idx)
        current_idx = 0
        while current_idx < num_el:
            batch_idx = idx[current_idx:current_idx+num]
            current_idx += num
            data_shuffle = [data[ i,:] for i in batch_idx]
            labels_shuffle = [labels[ i] for i in batch_idx]
            yield np.asarray(data_shuffle), np.asarray(labels_shuffle)

n_samples = 10000
batch_size =128

with tf.Session() as sess:
    sess.run(init)
    n_batches = int(n_samples / batch_size)
    next_batch_gen = next_batch(batch_size, train_X, train_Y)
    for i in range(n_epochs):
        for j in range(n_batches):
            X_batch, Y_batch = next(next_batch_gen)
            print(Y_batch)