我的目录中有大约一百万张图片。我想创建一个batch_generator,以便可以训练CNN,因为我无法一次将所有这些图像保存在内存中。
所以,我写了一个生成器函数来做到这一点-
def batch_generator(image_paths, batch_size, isTraining):
while True:
batch_imgs = []
batch_labels = []
type_dir = 'train' if isTraining else 'test'
for i in range(len(image_paths)):
print(i)
print(os.path.join(data_dir_base, type_dir, image_paths[i]))
img = cv2.imread(os.path.join(data_dir_base, type_dir, image_paths[i]), 0)
img = np.divide(img, 255)
img = img.reshape(28, 28, 1)
batch_imgs.append(img)
label = image_paths[i].split('_')[1].split('.')[0]
batch_labels.append(label)
if len(batch_imgs) == batch_size:
yield (np.asarray(batch_imgs), np.asarray(batch_labels))
batch_imgs = []
if batch_imgs:
yield batch_imgs
当我调用此语句时-
index = next(batch_generator(train_dataset, 10, True))
它正在打印相同的索引值和路径,因此,它在每次调用next()时都返回相同的批处理。 我该如何解决?
我将此问题作为代码的参考-how to split an iterable in constant-size chunks
答案 0 :(得分:0)
生成器函数本身不是生成器,而是“生成器工厂”-每次调用batch_generator(...)
时,它都会返回一个全新的生成器,准备再次启动。 IOW,您要:
gen = batch_generator(...)
for batch in gen:
do_something_with(batch)
也:
1 /您编写生成器函数的方式将创建一个无限生成器-外部while循环将永远重复-可能与您期望的不一样(我想最好警告您)。
2 /您的代码中存在两个逻辑错误:首先,您没有重置batch_labels
列表,然后在最后一个yield
上您只产生了batch_imgs
,这不是与内部yield
保持一致。 FWIW,而不是维护两个列表(一个用于图像,另一个用于标签),您可能最好使用一个(img, label)
元组的单个列表。
最后一点要注意:您不需要使用range(len(lst))
来迭代列表-Python的for
循环是foreach
类型的,它直接在iterable的项目,即:
for path image_paths:
print(path)
工作原理相同,可读性更高,并且速度更快...
答案 1 :(得分:0)
在我看来,您正在努力实现以下目标:
def batch_generator(image_paths, batch_size, isTraining):
your_code_here
呼叫生成器-而不是拥有的:
index = next(batch_generator(train_dataset, 10, True))
您可以尝试:
index = iter(batch_generator(train_dataset, 10, True))
index.__next__()
答案 2 :(得分:0)
# batch generator
def get_batches(dataset, batch_size):
X, Y = dataset
n_samples = X.shape[0]
# Shuffle at the start of epoch
indices = np.arange(n_samples)
np.random.shuffle(indices)
for start in range(0, n_samples, batch_size):
end = min(start + batch_size, n_samples)
batch_idx = indices[start:end]
yield X[batch_idx], Y[batch_idx]
答案 3 :(得分:0)
我制作了自己的生成器,支持限制、批处理或简单的第 1 步迭代:
def gen(batch = None, limit = None):
ret = []
for i in range(1, 11): # put your data reading here and i counter (i += 1) under for
if batch:
ret.append(i)
if limit and i == limit:
if len(ret):
yield ret
return
if len(ret) == batch:
yield ret
ret = []
else:
if limit and i > limit:
break
yield i
if batch and len(ret): # yield the rest of the list
yield ret
g = gen(batch=5, limit=8) # batches with limit
#g = gen(batch=5) # batches
#g = gen(limit=5) # step 1 with limit
#g = gen() # step 1 with limit
for i in g:
print(i)