MXNET多重迭代器(将.rec迭代器与NDArray迭代器结合使用)

时间:2017-08-22 01:17:10

标签: mxnet

如何在MXNET中创建组合迭代器?例如,给定一个记录(.rec)迭代器,如果我想更改与每个图像对应的标签,那么有两个选项: a)使用相同的数据(图像)和新标签创建一个新的rec迭代器。 b)使用原始rec迭代器和NDArray迭代器创建一个多迭代器,以便多迭代器从原始.rec迭代器和NDArray迭代器中的标签读取数据(图像)。 选项(a)很乏味。关于如何创建这样一个多迭代器的任何建议?

1 个答案:

答案 0 :(得分:4)

class MultiIter(mx.io.DataIter):  
    def __init__(self, iter_list):  
        self.iters = iter_list   
        self.batch_size = 1000  
    def next(self):  
        batches = [i.next() for i in self.iters]  
        return mx.io.DataBatch(data=[t for t in batches[0].data]+ [t for t in batches[1].data], label= [t for t in batches[0].label] + [t for t in batches[1].label],pad=0)  
    def reset(self):  
        for i in self.iters:  
            i.reset()  
    @property  
    def provide_data(self):  
        return [t for t in self.iters[0].provide_data] + [t for t in self.iters[1].provide_data] 
    @property  
    def provide_label(self):  
        return [t for t in self.iters[0].provide_label] + [t for t in self.iters[1].provide_label]

train = MultiIter([train1,train2])

train1和train2可以是任何两个DataIter。特别是,train1可以是.rec迭代器,train2可以是NDArray迭代器。如果train1或train2中的任何一个是NDArray迭代器,则使用组合迭代器调用predict方法需要附加参数“pad = 0”。

MultiIter返回数据列表和两个迭代器组合的标签列表。如果只需要第一个迭代器中的数据和第二个迭代器中的标签,下面的代码就可以工作。

class MultiIter(mx.io.DataIter):  
    def __init__(self, iter_list):  
        self.iters = iter_list   
        self.batch_size = 1000  
    def next(self):  
        batches = [i.next() for i in self.iters]  
        return mx.io.DataBatch(data=[t for t in batches[0].data], label= [t for t in batches[1].label],pad=0)  
    def reset(self):  
        for i in self.iters:  
            i.reset()  
    @property  
    def provide_data(self):  
        return [t for t in self.iters[0].provide_data] 
    @property  
    def provide_label(self):  
        return [t for t in self.iters[1].provide_label] 

train = MultiIter([train1,train2])