我正在尝试从输入生成器函数创建批处理列表,但它不会产生我期望的列表。
def batch_generator(items, batch_size):
new = []
i = 0
for item in items:
new.append(item)
i += 1
print('new: ', new, i)
if i == batch_size:
print('i = batch')
i = 0
yield new
new = []
def _test_items_generator():
for i in range(10):
yield i
print(list(map(lambda x: len(x),
batch_generator(_test_items_generator(), 3))))
我试图让输出为[[0,1,2],[3,4,5],[6,7,8],[9]] yield似乎是发送batch_size而不是新列表中的信息。试着让我理解这些发电机是如何工作的!
答案 0 :(得分:1)
我认为问题在于你的最后一行:
print(list(map(lambda x: len(x),
batch_generator(_test_items_generator(), 3))))
batch_generator
会产生new
,其中包含一个列表。然后,您的map(lambda x: len(x)
将返回每个列表的len。然后打印map()
返回的长度列表。
以下是产生预期输出的代码:
def batch_generator(items, batch_size):
new = []
i = 0
for item in items:
new.append(item)
i += 1
print('new: ', new, i)
if i == batch_size:
print('i = batch')
i = 0
yield new
new = []
yield new # yield the last list even if it is smaller than batch size
def _test_items_generator():
for i in range(10):
yield i
print(list( batch_generator(_test_items_generator(), 3)))
答案 1 :(得分:0)
您的生成器工作正常。但在测试中,您将结果列表映射到其大小lambda x: len(x)
答案 2 :(得分:0)
另一种batch_generator
函数的方法:
def batch_generator(items, batch_size):
current_batch = []
for i, item in enumerate(items):
current_batch.append(item)
if len(current_batch) == batch_size:
yield current_batch
current_batch = []
if len(current_batch) < batch_size:
yield current_batch