I inherited this code:

    # module-level imports used by the two methods below
    import json

    import requests
    from requests.exceptions import RequestException

    def _api_call(self, url, query=None, raw_query='', top=None, select=None, order_by=None, fetch_all=True,
                  full_response=False, timeout=10, retries=3, debugging=False):
        if fetch_all and full_response:
            # these parameters are not compatible
            raise APIQueryException(message='fetch_all and full_response cannot be used together')
        query_string = self._encode_query(query, top, select, order_by)
        if query_string or raw_query:
            query_string = '?' + query_string + raw_query
        # get API signature headers for the new request from the api_auth singleton
        headers = api_auth.generate_headers(path=url, method='GET')
        service_end_point = 'https://%(portal)s.company.com%(url)s%(query_string)s' % {
            'portal': self._portal,
            'url': url,
            'query_string': query_string,
        }
        retries_left = retries + 1 or 1
        stop = False
        kwargs = {'headers': headers, 'timeout': timeout, 'fetch_all': fetch_all}
        accumulated_results = []
        # keep requesting pages until _single_api_call reports that there are no more
        while not stop:
            service_end_point, _tmp_res, stop = self._single_api_call(service_end_point, retries_left, stop, **kwargs)
            accumulated_results.extend(_tmp_res)
        return accumulated_results

    def _single_api_call(self, service_end_point, retries_left, stop, debugging=True, **kwargs):
        _res = []
        headers = kwargs.pop('headers')
        timeout = kwargs.pop('timeout')
        fetch_all = kwargs.pop('fetch_all')
        try:
            while True:
                if debugging:
                    print('Company Service API:', service_end_point)
                try:
                    result = requests.get(url=service_end_point, headers=headers, timeout=timeout)
                    break
                except RequestException as e:
                    # retry until the retry budget is used up, then re-raise
                    if retries_left > 0:
                        if debugging:
                            print('Company Service API EXCEPTION, retrying:', str(e))
                        retries_left -= 1
                    else:
                        raise
        except requests.Timeout as e:
            raise APITimeoutException(e, message='API request timeout')
        except requests.ConnectionError as e:
            raise APIRequestException(e, message='API request DNS error or connection refused')
        except requests.TooManyRedirects as e:
            raise APIRequestException(e, message='API request had too many redirects')
        except requests.HTTPError as e:
            raise APIRequestException(e, message='API request returned HTTP error status code')
        if result.status_code == 400:
            # Company often reports "Bad Request" if the query
            # parameters are not acceptable
            raise APIQueryException(message='API request failed, Company rejected query terms')
        try:
            parsed_result = json.loads(result.content)
        except ValueError as e:
            # an unknown failure mode
            raise APIRequestException(
                message='API request failed; no JSON returned; server said {}'.format(result.content))
        if 'value' in parsed_result:
            _res.extend(parsed_result['value'])
        else:
            pass
        if '@odata.nextLink' in parsed_result and fetch_all:
            service_end_point = parsed_result['@odata.nextLink']
        else:
            # no more pages
            stop = True
        return service_end_point, _res, stop

This works fine:

    call_1 = api_obj._api_call(url, *args, **kwargs)
    len(call_1)
    3492

However, I tried to refactor it to use a generator, and I messed something up.
I made the following change to the `while not stop` part of the `_api_call` method:

    while not stop:
        try:
            service_end_point, _tmp_res, stop = self._single_api_call(service_end_point, retries_left, stop, **kwargs)
            accumulated_results.extend(_tmp_res)
            if stop:
                raise StopIteration
            else:
                yield _tmp_res
        except StopIteration:
            return accumulated_results

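I can see that every call is made, but the result is:

    call_2 = api_obj._api_call(url, *args, **kwargs)
    len(call_2)
    3

Each of these three items is a list of 1000 items, so I end up with 3000 items spread over separate lists instead of the 3492 items from the original implementation.
How can I change or rewrite this so that the generator version produces all 3492 items?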
Answer 0 (score: 3)
You're trying to do it both ways at once, with `yield` and with `return`. That is legal, but it doesn't mean what you probably want it to mean.
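Also, you don't need to raise `StopIteration` just to catch it and turn it into a `return`, which the generator protocol then turns back into a `StopIteration`. Just `return` and cut out two of those steps (and two extra chances to get something wrong). Or, in this case, we can simply fall off the end of the `while not stop:` loop, just like the original code does, and drop the `return` altogether, because we then fall off the end of the function.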
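Meanwhile, your old code added each `_tmp_res` to the list with `extend`, not `append`, which has the effect of "flattening" the list: if `_tmp_res` is a list of 1000 items, `extend` adds 1000 items to the end of the list. But `yield _tmp_res` just yields the 1000-item sublist as a single value. You probably want `yield from` here:

    while not stop:
        service_end_point, _tmp_res, stop = self._single_api_call(service_end_point, retries_left, stop, **kwargs)
        yield from _tmp_res
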
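The difference can be reproduced without the real API. The sketch below is an assumption-laden stand-in: `FAKE_PAGES` and `fetch_page` are hypothetical and only mimic the `(next_endpoint, items, stop)` shape that `_single_api_call` returns. It also shows where the missing 492 items went: the last page arrives together with `stop == True`, and the attempt in the question stops before yielding it.

```python
# Hypothetical stand-in for the paginated API: three pages, the last one short.
FAKE_PAGES = [[1, 2, 3], [4, 5, 6], [7, 8]]


def fetch_page(index):
    """Return (next_index, items, stop), loosely mirroring _single_api_call."""
    items = FAKE_PAGES[index]
    stop = index == len(FAKE_PAGES) - 1
    return index + 1, items, stop


def broken_generator():
    # Same shape as the attempt in the question: whole pages are yielded,
    # and the final page is dropped because stop is checked before yielding.
    index, stop = 0, False
    while not stop:
        index, items, stop = fetch_page(index)
        if stop:
            return
        yield items


def fixed_generator():
    # Yield every item of every page, including the final one.
    index, stop = 0, False
    while not stop:
        index, items, stop = fetch_page(index)
        yield from items


broken = list(broken_generator())
fixed = list(fixed_generator())
```

Note that a generator has no `len()`; to count the items you would wrap the call in `list(...)` first, as done here.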
If you don't know what `yield from` means, it's roughly equivalent (in this case) to:

    for element in _tmp_res:
        yield element

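As a quick check of that equivalence, the following sketch runs both spellings over the same input. The function names and the `sample` data are illustrative only:

```python
def flatten_with_yield_from(pages):
    # Python 3.3+ spelling: delegate to each page in turn.
    for page in pages:
        yield from page


def flatten_with_loop(pages):
    # Explicit loop; this form also works where `yield from` is unavailable.
    for page in pages:
        for element in page:
            yield element


sample = [[1, 2], [], [3]]
from_result = list(flatten_with_yield_from(sample))
loop_result = list(flatten_with_loop(sample))
```

Both generators produce the same flat sequence, including for empty pages.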
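In general, `yield from` is much more powerful than that, but we don't need any of that power here. It will still be a bit more efficient (though probably not enough to make a difference), and of course it's shorter and simpler, and it makes more sense once you wrap your head around the idea. But if your code needs to run on Python 2.7, you won't have `yield from`, so you'll have to use the explicit loop with `yield` instead.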
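For the curious, one piece of that extra power is that `yield from` evaluates to the value the delegated generator returns, which a plain `for` loop cannot capture. A minimal sketch, with made-up `inner` and `outer` generators:

```python
def inner():
    yield 1
    yield 2
    return "inner finished"  # becomes the value of the `yield from` expression


def outer():
    # Re-yields everything inner() yields, then receives its return value.
    result = yield from inner()
    yield result


values = list(outer())
```

`yield from` also forwards `send()` and `throw()` calls to the delegated generator, none of which matters for simply flattening pages of results.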