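I'm trying to write a simple asynchronous data batch generator, but I'm having trouble understanding how to yield from an async for loop. Here is a simple class I wrote to illustrate the idea: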
import asyncio
from typing import List


class AsyncSimpleIterator:
    def __init__(self, data: List[str], batch_size=None):
        self.data = data
        self.batch_size = batch_size
        self.doc2index = self.get_doc_ids()

    def get_doc_ids(self):
        return list(range(len(self.data)))

    async def get_batch_data(self, doc_ids):
        print("get_batch_data() running")
        page = [self.data[j] for j in doc_ids]
        return page

    async def get_docs(self, batch_size):
        print("get_docs() running")
        _batch_size = self.batch_size or batch_size
        batches = [self.doc2index[i:i + _batch_size] for i in
                   range(0, len(self.doc2index), _batch_size)]
        for _, doc_ids in enumerate(batches):
            docs = await self.get_batch_data(doc_ids)
            yield docs, doc_ids

    async def main(self):
        print("main() running")
        async for res in self.get_docs(batch_size=2):
            print(res)  # how to yield instead of print?

    def gen_batches(self):
        # how to get results of self.main() here?
        loop = asyncio.get_event_loop()
        loop.run_until_complete(self.main())
        loop.close()


DATA = ["Hello, world!"] * 4
iterator = AsyncSimpleIterator(DATA)
iterator.gen_batches()
So my question is: how do I yield the results from main() so that I can collect them in gen_batches()?
When I print the results in main(), I get the following output:
main() running
get_docs() running
get_batch_data() running
(['Hello, world!', 'Hello, world!'], [0, 1])
get_batch_data() running
(['Hello, world!', 'Hello, world!'], [2, 3])
Answer 0 (score: 1)
> I'm trying to write a simple async data batch generator, but I'm having trouble understanding how to yield from an async for loop
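Yielding from inside an async for works like a regular yield, except that the produced values must in turn be collected by an async for or an equivalent construct. For example, the yield in get_docs() makes it an async generator. If you replace print(res) with yield res in main(), that turns main() into an async generator as well.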
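A minimal, self-contained sketch of that change follows. It assumes a simplified, module-level get_docs() in which asyncio.sleep(0) stands in for the real async fetch; those simplifications are illustrative and not from the question. main() now yields each batch, and whoever consumes it has to do so with async for inside a coroutine.

```python
import asyncio
from typing import AsyncIterator, List, Tuple

Batch = Tuple[List[str], List[int]]


async def get_docs(data: List[str], batch_size: int) -> AsyncIterator[Batch]:
    # Simplified stand-in for the question's get_docs(): an async generator of (docs, doc_ids).
    doc_ids = list(range(len(data)))
    for i in range(0, len(doc_ids), batch_size):
        ids = doc_ids[i:i + batch_size]
        await asyncio.sleep(0)  # placeholder for a real async fetch
        yield [data[j] for j in ids], ids


async def main(data: List[str]) -> AsyncIterator[Batch]:
    # Replacing print(res) with yield res turns main() into an async generator too.
    async for res in get_docs(data, batch_size=2):
        yield res


async def consume() -> List[Batch]:
    # An async generator can only be iterated with async for (or an async comprehension).
    out = []
    async for res in main(["Hello, world!"] * 4):
        out.append(res)
    return out


batches = asyncio.run(consume())
print(batches)
```

Iterating main(...) with a plain for loop instead raises TypeError ('async_generator' object is not iterable), which is why the values have to be gathered from inside a coroutine that an event loop runs.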
> The generator in main() should be exhausted in gen_batches(), so that I can collect all the results in gen_batches()
To collect the values produced by an async generator (such as main() after replacing print(res) with yield res), you can use a helper coroutine:
def gen_batches(self):
    loop = asyncio.get_event_loop()

    async def collect():
        return [item async for item in self.main()]

    items = loop.run_until_complete(collect())
    loop.close()
    return items
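On Python 3.7 and later, the same collect() helper can also be driven with asyncio.run(), which creates a fresh event loop, runs the coroutine to completion and closes the loop. This is a sketch rather than part of the original answer; the trimmed-down class below drops get_batch_data() and the print() calls for brevity. A side benefit is that gen_batches() can be called more than once: closing the loop obtained from asyncio.get_event_loop(), as above, typically makes a second call in the same thread fail with an "Event loop is closed" error, and calling get_event_loop() when no loop is running is deprecated in recent Python releases. Note that asyncio.run() itself cannot be called while another event loop is already running in the same thread.

```python
import asyncio
from typing import List


class AsyncSimpleIterator:
    # Trimmed-down version of the question's class, keeping only what collect() needs.
    def __init__(self, data: List[str], batch_size: int = 2):
        self.data = data
        self.batch_size = batch_size

    async def get_docs(self):
        doc_ids = list(range(len(self.data)))
        for i in range(0, len(doc_ids), self.batch_size):
            ids = doc_ids[i:i + self.batch_size]
            yield [self.data[j] for j in ids], ids

    async def main(self):
        # main() with print(res) replaced by yield res
        async for res in self.get_docs():
            yield res

    def gen_batches(self):
        async def collect():
            return [item async for item in self.main()]

        # asyncio.run() creates a new loop, runs collect() and closes the loop afterwards.
        return asyncio.run(collect())


result = AsyncSimpleIterator(["Hello, world!"] * 4).gen_batches()
print(result)
```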
The collect() helper uses a PEP 530 asynchronous comprehension, which can be thought of as syntactic sugar for the more explicit:
async def collect():
    l = []
    async for item in self.main():
        l.append(item)
    return l
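The comprehension form is not limited to plain collection: PEP 530 also allows an if clause and await expressions inside an async comprehension. The small sketch below uses made-up numbers() and double() coroutines (not from the question) to show the comprehension and the explicit loop producing the same list.

```python
import asyncio
from typing import AsyncIterator, List


async def numbers(n: int) -> AsyncIterator[int]:
    # Hypothetical async generator used only for this demo.
    for i in range(n):
        await asyncio.sleep(0)
        yield i


async def double(x: int) -> int:
    # Hypothetical coroutine awaited inside the comprehension.
    await asyncio.sleep(0)
    return 2 * x


async def collect_comprehension() -> List[int]:
    # async for with a filter and an await in the element expression (PEP 530).
    return [await double(i) async for i in numbers(6) if i % 2 == 0]


async def collect_explicit() -> List[int]:
    # The same logic spelled out as an explicit async for loop.
    out = []
    async for i in numbers(6):
        if i % 2 == 0:
            out.append(await double(i))
    return out


print(asyncio.run(collect_comprehension()))  # [0, 4, 8]
print(asyncio.run(collect_explicit()))       # [0, 4, 8]
```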
Answer 1 (score: 0)
A working solution to the original question, based on @user4815162342's answer:
import asyncio
from typing import List


class AsyncSimpleIterator:
    def __init__(self, data: List[str], batch_size=None):
        self.data = data
        self.batch_size = batch_size
        self.doc2index = self.get_doc_ids()

    def get_doc_ids(self):
        return list(range(len(self.data)))

    async def get_batch_data(self, doc_ids):
        print("get_batch_data() running")
        page = [self.data[j] for j in doc_ids]
        return page

    async def get_docs(self, batch_size):
        print("get_docs() running")
        _batch_size = self.batch_size or batch_size
        batches = [self.doc2index[i:i + _batch_size] for i in
                   range(0, len(self.doc2index), _batch_size)]
        for _, doc_ids in enumerate(batches):
            docs = await self.get_batch_data(doc_ids)
            yield docs, doc_ids

    def gen_batches(self):
        loop = asyncio.get_event_loop()

        async def collect():
            return [j async for j in self.get_docs(batch_size=2)]

        items = loop.run_until_complete(collect())
        loop.close()
        return items


DATA = ["Hello, world!"] * 4
iterator = AsyncSimpleIterator(DATA)
result = iterator.gen_batches()
print(result)
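Both answers gather every batch into a list before gen_batches() returns. If gen_batches() should instead be an ordinary generator that hands out batches one at a time (closer to the "how to yield instead of print?" comment in the question), one possible approach is to step the async generator manually on a private event loop. The helper iterate_sync() and the stand-in fetch_batches() below are hypothetical names, not part of either answer; this is a sketch, and it cannot be used from code that is already running inside an event loop.

```python
import asyncio
from typing import AsyncGenerator, Iterator, List, Tuple, TypeVar

T = TypeVar("T")


def iterate_sync(agen: AsyncGenerator[T, None]) -> Iterator[T]:
    # Hypothetical helper: expose an async generator as a plain (sync) generator.
    # Must not be called from code already running inside an event loop.
    loop = asyncio.new_event_loop()
    try:
        while True:
            try:
                # Run the loop only until the async generator produces its next item.
                item = loop.run_until_complete(agen.__anext__())
            except StopAsyncIteration:
                break
            yield item
    finally:
        # Finalize the async generator (even if the consumer stopped early), then close the loop.
        loop.run_until_complete(agen.aclose())
        loop.close()


async def fetch_batches(
    data: List[str], batch_size: int
) -> AsyncGenerator[Tuple[List[str], List[int]], None]:
    # Hypothetical stand-in for the question's get_docs(): yields (docs, doc_ids) pairs.
    for start in range(0, len(data), batch_size):
        doc_ids = list(range(start, min(start + batch_size, len(data))))
        await asyncio.sleep(0)  # placeholder for real async I/O
        yield [data[j] for j in doc_ids], doc_ids


for docs, doc_ids in iterate_sync(fetch_batches(["Hello, world!"] * 4, 2)):
    print(doc_ids, docs)
```

With such a helper, the question's method could become something like def gen_batches(self): yield from iterate_sync(self.get_docs(batch_size=2)). The trade-off is that each batch is fetched only when the caller asks for it, instead of all batches being collected up front.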