在mxnet中使用我自己的python数据迭代器时出错

时间:2016-12-25 17:30:03

标签: python mxnet

我正在尝试创建自己的数据迭代器以与mxnet一起使用。当我运行它时,我收到错误:

Traceback (most recent call last):
File "train.py", line 24, in <module>
batch_end_callback = mx.callback.Speedometer(batch_size, 1) # output progress for each 200 data batches
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/model.py", line 811, in fit
sym_gen=self.sym_gen)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/model.py", line 236, in _train_multi_device
executor_manager.load_data_batch(data_batch)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 410, in load_data_batch
self.curr_execgrp.load_data_batch(data_batch)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 257, in load_data_batch
_load_data(data_batch, self.data_arrays)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 93, in _load_data
_load_general(batch.data, targets)
File "/usr/local/lib/python2.7/dist-packages/mxnet-0.7.0-py2.7.egg/mxnet/executor_manager.py", line 89, in _load_general
d_src[slice_idx].copyto(d_dst)
AttributeError: 'numpy.ndarray' object has no attribute 'copy'

我认为这与我返回数据的方式有关。请参阅下面的数据迭代器代码:

from mxnet.io import DataIter, DataDesc
import csv
from random import shuffle
import numpy as np
from cv2 import imread, resize

class MyData(DataIter):
    def __init__(self, root_dir, flist_name, batch_size, size=(256,256), shuffle=True):
        super(MyData, self).__init__()
        self.batch_size = batch_size
        self.root_dir = root_dir
        self.flist_name = flist_name
        self.size = size
        self.shuffle = shuffle

        self.data = []
        with open(flist_name, 'rb') as csvfile:
            csvreader = csv.reader(csvfile)
            for row in csvreader:
                self.data.append(row)
        self.num_data = len(self.data)
        self.provide_data = [DataDesc('data', (self.batch_size, 6, self.size[0], self.size[1]), np.float32)]
        self.provide_label = [DataDesc('Pa_label', (self.batch_size, 1), np.float32)]
        self.reset()

    def reset(self):
        """Reset the iterator. """
        self.cursor = 0
        if self.shuffle:
            shuffle(self.data)

    def iter_next(self):
        """Iterate to next batch.
        Returns
        -------
        has_next : boolean
            Whether the move is successful.
        """
        self.cursor += self.batch_size
        success = self.cursor < self.num_data
        return success

    def getdata(self):
        """Get data of current batch.
        Returns
        -------
        data : NDArray
            The data of current batch.
        """
        datalist = self.data[self.cursor:self.cursor+self.batch_size]
        ret = np.ndarray(shape=(0,6,self.size[0],self.size[1]), dtype=np.float32)
        for data_row in datalist:
            img1 = resize(imread(data_row[0]), self.size)
            img2 = resize(imread(data_row[1]), self.size)
            img1 = np.rollaxis(img1, 2)
            img2 = np.rollaxis(img2, 2)
            img = np.concatenate((img1, img2), 0)
            imge = np.expand_dims(img,0)
            ret = np.append(ret, imge, 0)

        print ret.shape
        pad = self.batch_size - ret.shape[0]
        if pad > 0:
            ret = np.append(ret, np.zeros((pad, 6, self.size[0], self.size[1])), 0)
        return ret

    def getlabel(self):
        """Get label of current batch.
        Returns
        -------
        label : NDArray
            The label of current batch.
        """
        datalist = self.data[self.cursor:self.cursor+self.batch_size]
        ret = np.ndarray(shape=(0,1,1,1), dtype=np.float32)
        for data_row in datalist:
            label = np.ndarray(shape=(1,1,1,1), dtype=np.float32)
            label[0,0,0,0] = float(data_row[2]) / float(data_row[5])
            np.append(ret, label, 0)

        pad = self.batch_size - ret.shape[0]
        np.append(ret, np.zeros((pad, 1,1,1)), 0)
        return ret

    def getindex(self):
        """Get index of the current batch.
        Returns
        -------
        index : numpy.array
            The index of current batch
        """
        return self.cursor

    def getpad(self):
        """Get the number of padding examples in current batch.
        Returns
        -------
        pad : int
            Number of padding examples in current batch
        """
        if self.cursor + self.batch_size > self.num_data:
            return self.cursor + self.batch_size - self.num_data
        else:
            return 0

1 个答案:

答案 0 :(得分:1)

numpy.ndarray没有copyto方法。尝试使用mx.ndarray。