我遇到了一个错误,在该错误中,遍历数据的迭代器在返回末尾之前在集合的末尾提供了比预期短的余数。我需要所有迭代的大小都完全相同,因此我希望它删除剩余的部分。不幸的是,迭代器被称为另一个对象的参数,因此我无法控制何时调用__next__()
。想到的解决方案是创建一个自迭代器继承的自定义类,并且仅重新定义__next__()
。
我正在尝试如下操作:
class CleanIterator(NumpyArrayIterator):
def __init__(self, _super):
super=_super
def __next__(self, *args, **kwargs):
return self.next(*args, **kwargs)
def next(self, *args, **kwargs):
while True:
data= super.next(*args, **kwargs)
# reject short data snippets
if data[0].shape[0] == self.batch_size:
return data
# .flow return a NumpyArrayIterator
data_generator= ImageDataGenerator().flow(
valid_data, valid_labels,
batch_size= 100)
data_generator= CleanIterator(data_generator)
这设法从多个继承级别继承函数,但似乎只从NumpyArrayIterator继承变量。
结果我得到了这样的错误:
venv/lib/python3.6/site-packages/keras_preprocessing/image/iterator.py", line 68, in __len__
return (self.n + self.batch_size - 1) // self.batch_size # round up
AttributeError: 'CleanIterator' object has no attribute 'n'
其中NumpyArrayIterator继承自Iterator。迭代器具有变量self.n
和函数__len__()
。
我试图寻找正确的语法,但是我能找到的每个示例都是从头开始构建父类,而不是从现有实例开始。
所以我想问题是:如何从父类的旧实例创建子类的新实例?
答案 0 :(得分:0)
我不确定这是否是您想要的。对于您的要求我很困惑,这是否能回答您的问题?
class first:
def test_function(self):
print('it works')
instance = first()
class secondary(instance.__class__):
"""This class adds functionality to the already existent instance of the First class"""
def __init__(self):
super().__init__()
def added_function(self):
pass
如您所见,通过使用实例的__class__
属性,我们可以访问其类。
答案 1 :(得分:0)
如果您要过滤的内容,就不能只使用itertools.filterfalse
吗?
import itertools
...
data_generator = itertools.filterfalse(
(lambda x:x[0].shape[0] == x.batch_size),
data_generator)
答案 2 :(得分:0)
def clean_itterator(fn, index_array, generator):
while True:
data= fn(index_array)
# reject under-sized batches
if data[0].shape[0] == generator.batch_size:
return data
return None
data_generator= ImageDataGenerator().flow(
valid_data, valid_labels,
batch_size= 100)
original_fn= data_generator._get_batches_of_transformed_samples
data_generator._get_batches_of_transformed_samples= \
lambda index_array: clean_itterator(
original_fn,
index_array,
data_generator)
这是丑陋的罪恶,它无法按照我想要的方式运行,但它可以运行