我实现了一个迭代器类,如下所示:
import numpy as np
import time
class Data:
def __init__(self, filepath):
# Computationaly expensive
print("Computationally expensive")
time.sleep(10)
print("Done!")
def __iter__(self):
return self
def __next__(self):
return np.zeros((2,2)), np.zeros((2,2))
count = 0
for batch_x, batch_y in Data("hello.csv"):
print(batch_x, batch_y)
count = count + 1
if count > 5:
break
count = 0
for batch_x, batch_y in Data("hello.csv"):
print(batch_x, batch_y)
count = count + 1
if count > 5:
break
但是构造函数的计算量很大,并且可能会多次调用for循环。例如,在上面的代码中,构造函数被调用两次(每个for循环创建一个新的Data对象)。
如何分隔构造函数和迭代器?我希望有以下代码,其中构造函数只调用一次:
data = Data(filepath)
for batch_x, batch_y in data.get_iterator():
print(batch_x, batch_y)
for batch_x, batch_y in data.get_iterator():
print(batch_x, batch_y)
答案 0 :(得分:2)
您可以直接迭代可迭代对象,for..in
不需要任何其他内容:
data = Data(filepath)
for batch_x, batch_y in data:
print(batch_x, batch_y)
for batch_x, batch_y in data:
print(batch_x, batch_y)
也就是说,根据您实施__iter__()
的方式,这可能是错误的。
E.g:
class Data:
def __init__(self, filepath):
self._items = load_items(filepath)
self._i = 0
def __iter__(self): return self
def __next__(self):
if self._i >= len(self._items): # Or however you check if data is available
raise StopIteration
result = self._items[self._i]
self._i += 1
return result
因为那时你不能两次迭代同一个对象,因为self._i
仍然指向循环的结尾。
class Data:
def __init__(self, filepath):
self._items = load_items(filepath)
def __iter__(self):
self._i = 0
return self
def __next__(self):
if self._i >= len(self._items):
raise StopIteration
result = self._items[self._i]
self._i += 1
return result
这会在您每次重复迭代时重置索引,修复上述内容。如果您在同一个对象上嵌套迭代,这将无法工作。
要解决此问题,请将迭代状态保存在单独的迭代器对象中:
class Data:
class Iter:
def __init__(self, data):
self._data = data
self._i = 0
def __next__(self):
if self._i >= len(self._data._items): # check for available data
raise StopIteration
result = self._data._items[self._i]
self._i = self._i + 1
def __init__(self, filepath):
self._items = load_items(filepath)
def __iter__(self):
return self.Iter(self)
这是最灵活的方法,但如果你可以使用以下任何一种,那就不必要了。
yield
如果您使用Python的生成器,该语言将负责跟踪迭代状态,即使嵌套循环也应该正确执行:
class Data:
def __init__(self, filepath):
self._items= load_items(filepath)
def __iter__(self):
for it in self._items: # Or whatever is appropriate
yield return it
如果"计算成本高昂" part正在将所有数据加载到内存中,您可以直接使用缓存数据。
class Data:
def __init__(self, filepath):
self._items = load_items(filepath)
def __iter__(self):
return iter(self._items)
答案 1 :(得分:1)
不是创建Data
的新实例,而是创建第二个类IterData
,其中包含__init__
方法,该方法运行的过程不像实例化{{1}那样计算成本高。 }。然后,在Data
中创建classmethod
作为Data
的替代构造函数:
IterData