如何在MXNET中创建组合迭代器?例如,给定一个记录(.rec)迭代器,如果我想更改与每个图像对应的标签,那么有两个选项: a)使用相同的数据(图像)和新标签创建一个新的rec迭代器。 b)使用原始rec迭代器和NDArray迭代器创建一个多迭代器,以便多迭代器从原始.rec迭代器和NDArray迭代器中的标签读取数据(图像)。 选项(a)很乏味。关于如何创建这样一个多迭代器的任何建议?
答案 0 :(得分:4)
class MultiIter(mx.io.DataIter):
def __init__(self, iter_list):
self.iters = iter_list
self.batch_size = 1000
def next(self):
batches = [i.next() for i in self.iters]
return mx.io.DataBatch(data=[t for t in batches[0].data]+ [t for t in batches[1].data], label= [t for t in batches[0].label] + [t for t in batches[1].label],pad=0)
def reset(self):
for i in self.iters:
i.reset()
@property
def provide_data(self):
return [t for t in self.iters[0].provide_data] + [t for t in self.iters[1].provide_data]
@property
def provide_label(self):
return [t for t in self.iters[0].provide_label] + [t for t in self.iters[1].provide_label]
train = MultiIter([train1,train2])
train1和train2可以是任何两个DataIter。特别是,train1可以是.rec迭代器,train2可以是NDArray迭代器。如果train1或train2中的任何一个是NDArray迭代器,则使用组合迭代器调用predict方法需要附加参数“pad = 0”。
MultiIter返回数据列表和两个迭代器组合的标签列表。如果只需要第一个迭代器中的数据和第二个迭代器中的标签,下面的代码就可以工作。
class MultiIter(mx.io.DataIter):
def __init__(self, iter_list):
self.iters = iter_list
self.batch_size = 1000
def next(self):
batches = [i.next() for i in self.iters]
return mx.io.DataBatch(data=[t for t in batches[0].data], label= [t for t in batches[1].label],pad=0)
def reset(self):
for i in self.iters:
i.reset()
@property
def provide_data(self):
return [t for t in self.iters[0].provide_data]
@property
def provide_label(self):
return [t for t in self.iters[1].provide_label]
train = MultiIter([train1,train2])