类方法返回迭代器

时间:2018-03-08 17:18:15

标签: python iterator

我实现了一个迭代器类,如下所示:

import numpy as np
import time


class Data:

    def __init__(self, filepath):
        # Computationaly expensive
        print("Computationally expensive")
        time.sleep(10)
        print("Done!")

    def __iter__(self):
        return self

    def __next__(self):
        return np.zeros((2,2)), np.zeros((2,2))


count = 0
for batch_x, batch_y in Data("hello.csv"):
    print(batch_x, batch_y)
    count = count + 1

    if count > 5:
        break


count = 0
for batch_x, batch_y in Data("hello.csv"):
    print(batch_x, batch_y)
    count = count + 1

    if count > 5:
        break

但是构造函数的计算量很大,并且可能会多次调用for循环。例如,在上面的代码中,构造函数被调用两次(每个for循环创建一个新的Data对象)。

如何分隔构造函数和迭代器?我希望有以下代码,其中构造函数只调用一次:

data = Data(filepath)

for batch_x, batch_y in data.get_iterator():
    print(batch_x, batch_y)

for batch_x, batch_y in data.get_iterator():
    print(batch_x, batch_y)

2 个答案:

答案 0 :(得分:2)

您可以直接迭代可迭代对象,for..in不需要任何其他内容:

data = Data(filepath)

for batch_x, batch_y in data:
    print(batch_x, batch_y)

for batch_x, batch_y in data:
    print(batch_x, batch_y)

也就是说,根据您实施__iter__()的方式,这可能是错误的。

E.g:

class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)
        self._i = 0
    def __iter__(self): return self
    def __next__(self):
        if self._i >= len(self._items): # Or however you check if data is available
            raise StopIteration
        result = self._items[self._i]
        self._i += 1
        return result

因为那时你不能两次迭代同一个对象,因为self._i仍然指向循环的结尾。

好肥胖型

class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)
    def __iter__(self):
        self._i = 0
        return self
    def __next__(self):
        if self._i >= len(self._items):
            raise StopIteration
        result = self._items[self._i]
        self._i += 1
        return result

这会在您每次重复迭代时重置索引,修复上述内容。如果您在同一个对象上嵌套迭代,这将无法工作。

更好

要解决此问题,请将迭代状态保存在单独的迭代器对象中:

class Data:
    class Iter:
        def __init__(self, data):
            self._data = data
            self._i = 0
        def __next__(self):
            if self._i >= len(self._data._items): # check for available data
                raise StopIteration
            result = self._data._items[self._i]
            self._i = self._i + 1
    def __init__(self, filepath):
        self._items = load_items(filepath)
    def __iter__(self): 
        return self.Iter(self)

这是最灵活的方法,但如果你可以使用以下任何一种,那就不必要了。

简单,使用yield

如果您使用Python的生成器,该语言将负责跟踪迭代状态,即使嵌套循环也应该正确执行:

class Data:
    def __init__(self, filepath):
        self._items= load_items(filepath)
    def __iter__(self): 
        for it in self._items: # Or whatever is appropriate
            yield return it

简单,传递给基础可迭代

如果"计算成本高昂" part正在将所有数据加载到内存中,您可以直接使用缓存数据。

class Data:
    def __init__(self, filepath):
        self._items = load_items(filepath)
    def __iter__(self): 
        return iter(self._items)

答案 1 :(得分:1)

不是创建Data的新实例,而是创建第二个类IterData,其中包含__init__方法,该方法运行的过程不像实例化{{1}那样计算成本高。 }。然后,在Data中创建classmethod作为Data的替代构造函数:

IterData