如何在Python中定义批处理生成器?

时间:2019-01-18 07:25:36

标签: python-3.x

我的目录中有大约一百万张图片。我想创建一个batch_generator,以便可以训练CNN,因为我无法一次将所有这些图像保存在内存中。

所以,我写了一个生成器函数来做到这一点-

def batch_generator(image_paths, batch_size, isTraining):
    while True:
        batch_imgs = []
        batch_labels = []

        type_dir = 'train' if isTraining else 'test'

        for i in range(len(image_paths)):
            print(i)
            print(os.path.join(data_dir_base, type_dir, image_paths[i]))
            img = cv2.imread(os.path.join(data_dir_base, type_dir, image_paths[i]), 0)
            img  = np.divide(img, 255)
            img = img.reshape(28, 28, 1)
            batch_imgs.append(img)
            label = image_paths[i].split('_')[1].split('.')[0]
            batch_labels.append(label)
            if len(batch_imgs) == batch_size:
                yield (np.asarray(batch_imgs), np.asarray(batch_labels))
                batch_imgs = []
        if batch_imgs:
            yield batch_imgs

当我调用此语句时-

index = next(batch_generator(train_dataset, 10, True))

它正在打印相同的索引值和路径,因此,它在每次调用next()时都返回相同的批处理。 我该如何解决?

我将此问题作为代码的参考-how to split an iterable in constant-size chunks

4 个答案:

答案 0 :(得分:0)

生成器函数本身不是生成器,而是“生成器工厂”-每次调用batch_generator(...)时,它都会返回一个全新的生成器,准备再次启动。 IOW,您要:

gen = batch_generator(...)
for batch in gen:       
    do_something_with(batch)

也:

1 /您编写生成器函数的方式将创建一个无限生成器-外部while循环将永远重复-可能与您期望的不一样(我想最好警告您)。

2 /您的代码中存在两个逻辑错误:首先,您没有重置batch_labels列表,然后在最后一个yield上您只产生了batch_imgs,这不是与内部yield保持一致。 FWIW,而不是维护两个列表(一个用于图像,另一个用于标签),您可能最好使用一个(img, label)元组的单个列表。

最后一点要注意:您不需要使用range(len(lst))来迭代列表-Python的for循环是foreach类型的,它直接在iterable的项目,即:

for path image_paths:
    print(path)

工作原理相同,可读性更高,并且速度更快...

答案 1 :(得分:0)

在我看来,您正在努力实现以下目标:

def batch_generator(image_paths, batch_size, isTraining):
    your_code_here

呼叫生成器-而不是拥有的:

index = next(batch_generator(train_dataset, 10, True))

您可以尝试:

index = iter(batch_generator(train_dataset, 10, True))
index.__next__()

答案 2 :(得分:0)

# batch generator
def get_batches(dataset, batch_size):
    X, Y = dataset
    n_samples = X.shape[0]

    # Shuffle at the start of epoch
    indices = np.arange(n_samples)
    np.random.shuffle(indices)

    for start in range(0, n_samples, batch_size):
        end = min(start + batch_size, n_samples)

        batch_idx = indices[start:end]

        yield X[batch_idx], Y[batch_idx]

答案 3 :(得分:0)

我制作了自己的生成器,支持限制、批处理或简单的第 1 步迭代:

def gen(batch = None, limit = None):
    ret = []
    for i in range(1, 11): # put your data reading here and i counter (i += 1) under for
        if batch:
            ret.append(i)
            if limit and i == limit:
                if len(ret):            
                    yield ret
                return
            if len(ret) == batch:
                yield ret
                ret = []
        else:
            if limit and i > limit:
                break
            yield i
    if batch and len(ret): # yield the rest of the list
        yield ret
            
g = gen(batch=5, limit=8) # batches with limit
#g = gen(batch=5) # batches
#g = gen(limit=5) # step 1 with limit
#g = gen() # step 1 with limit
for i in g:
    print(i)