如何在自定义数据集上训练GluonCV的Faster RCNN?

我在[1]中改编了GluonCV教程的完整培训脚本,并参考了[2]中针对SSD的完整培训脚本。我从[3]中的训练脚本中成功地改编了YoloV3的示例,并成功地训练了SSD和YoloV3。这些脚本是为多GPU设计的,因此我简化了Jupyter Notebook中所有3个脚本的单GPU执行脚本。

但是,当我使用本教程中的Batchify函数时,Faster RCNN失败了:

train_bfn = batchify.Tuple(*[batchify.Append() for _ in range(5)]) 


train_bfn = batchify.FasterRCNNTrainBatchify(net) 

def get_dataloader(net, train_dataset, batch_size, num_workers):

    train_bfn = batchify.Tuple(*[batchify.Append() for _ in range(5)]) 
    #train_bfn = batchify.FasterRCNNTrainBatchify(net) #the train_bfn discovered from the gluoncv repos which ba

    #adding and removing train sampler didn't fix the loading problem
    #train_sampler = gcv.nn.sampler.SplitSampler(len(train_dataset),1)
    train_loader = mx.gluon.data.DataLoader(
        train_dataset.transform(FasterRCNNDefaultTrainTransform(net.short, net.max_size, net, ashape=net.ashape, multi_stage=True)),
        batch_size, False, batchify_fn=train_bfn,
        last_batch='rollover', num_workers=num_workers)
    return train_loader


#unusual loading method for Faster-RCNN
def split_and_load(batch, ctx_list):
    """Split data to 1 batch each device."""
    num_ctx = len(ctx_list)
    new_batch = []
    for i, data in enumerate(batch):
        new_data = [x.as_in_context(ctx) for x, ctx in zip(data, ctx_list)]
    return new_batch

print(f"Starting Training @ {start_epoch+1}/{end_epoch} Epochs. Reporting loss every {batch_reporter} batch")

start_time = time.time()
for epoch in range(start_epoch, end_epoch):
    rcnn_task = ForwardBackwardTask(net, trainer, rpn_cls_loss, rpn_box_loss, rcnn_cls_loss, rcnn_box_loss)
    executor = Parallel(1, rcnn_task)  
    for metric in rcnn_losses:
    tic = time.time()
    btic = time.time()
    tic = time.time() #currently unused
    btic = time.time() #batch time
    #setup the network into static computation graph for faster computation
    net.hybridize(static_alloc=True, static_shape=True)
    #main training loop, batch
    for i, batch in enumerate(train_batcher):
        #get batch size  
        batch = split_and_load(batch, ctx_list=ctx)
        batch_size = len(batch[0])
        batch_losses = [[] for _ in rcnn_losses]
        batch_metrics = [[] for _ in rcnn_metrics]
        #load the gpu context into the data, class targets and box targets
        for data in zip(*batch):
        #iteratte over the contexts
        for j in range(len(ctx)):
            result = executor.get()
            for k in range(len(metric_losses)):
            for k in range(len(add_losses)):
                batch_mertrics[k].append(result[len(batch_losses) + k])
        for metric, record in zip(rcnn_losses, batch_losses):
            metric.update(0, record)
        for metric, records in zip(rcnn_metrics, batch_metrics):
            for pred in records:
                metric.update(pred[0], pred[1])
         #update traininer

        if i % batch_reporter == 0:
            msg = ','.join(
                ['{}={:.3f}'.format(*metric.get()) for metric in rcnn_losses + rcnn_metrics])
            logger.info('[Epoch {}][Batch {}], Speed: {:.3f} samples/sec, {}'.format(
                epoch, i, batch_size / (time.time() - btic), msg))
            btic = time.time()


train_bfn = batchify.Tuple(*[batchify.Append() for _ in range(5)])
Starting Training @ 1/100 Epochs. Reporting loss every 20 batch
ValueError                                Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/ndarray.py in array(source_array, ctx, dtype)
   2500             try:
-> 2501                 source_array = np.array(source_array, dtype=dtype)
   2502             except:

ValueError: setting an array element with a sequence.

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
7 frames
<ipython-input-17-7873cb0811cc> in <module>()
     50     net.hybridize(static_alloc=True, static_shape=True)
     51     #main training loop, batch
---> 52     for i, batch in enumerate(train_batcher):
     53         #get batch size
     54         batch = split_and_load(batch, ctx_list=ctx)

/usr/local/lib/python3.6/dist-packages/mxnet/gluon/data/dataloader.py in same_process_iter()
    573             def same_process_iter():
    574                 for batch in self._batch_sampler:
--> 575                     ret = self._batchify_fn([self._dataset[idx] for idx in batch])
    576                     if self._pin_memory:
    577                         ret = _as_in_context(ret, context.cpu_pinned(self._pin_device_id))

/usr/local/lib/python3.6/dist-packages/gluoncv/data/batchify.py in __call__(self, data)
    378         ret = []
    379         for i, ele_fn in enumerate(self._fn):
--> 380             ret.append(ele_fn([ele[i] for ele in data]))
    381         return tuple(ret)

/usr/local/lib/python3.6/dist-packages/gluoncv/data/batchify.py in __call__(self, data)
    294         """
    295         return _append_arrs(data, use_shared_mem=True,
--> 296                             expand=self._expand, batch_axis=self._batch_axis)

/usr/local/lib/python3.6/dist-packages/gluoncv/data/batchify.py in _append_arrs(arrs, use_shared_mem, expand, batch_axis)
     92     else:
     93         if use_shared_mem:
---> 94             out = [mx.nd.array(x, ctx=mx.Context('cpu_shared', 0)) for x in arrs]
     95         else:
     96             out = [mx.nd.array(x) for x in arrs]

/usr/local/lib/python3.6/dist-packages/gluoncv/data/batchify.py in <listcomp>(.0)
     92     else:
     93         if use_shared_mem:
---> 94             out = [mx.nd.array(x, ctx=mx.Context('cpu_shared', 0)) for x in arrs]
     95         else:
     96             out = [mx.nd.array(x) for x in arrs]

/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/utils.py in array(source_array, ctx, dtype)
    144         return _sparse_array(source_array, ctx=ctx, dtype=dtype)
    145     else:
--> 146         return _array(source_array, ctx=ctx, dtype=dtype)

/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/ndarray.py in array(source_array, ctx, dtype)
   2501                 source_array = np.array(source_array, dtype=dtype)
   2502             except:
-> 2503                 raise TypeError('source_array must be array like object')
   2504     arr = empty(source_array.shape, ctx, dtype)
   2505     arr[:] = source_array

TypeError: source_array must be array like object


train_bfn = batchify.FasterRCNNTrainBatchify(net)
Starting Training @ 1/100 Epochs. Reporting loss every 20 batch
MXNetError                                Traceback (most recent call last)
<ipython-input-48-7873cb0811cc> in <module>()
     58         #load the gpu context into the data, class targets and box targets
     59         for data in zip(*batch):
---> 60             executor.put(data)
     61         #iteratte over the contexts
     62         for j in range(len(ctx)):

6 frames
/usr/local/lib/python3.6/dist-packages/gluoncv/utils/parallel.py in put(self, x)
    117         if self._num_serial > 0 or len(self._threads) == 0:
    118             self._num_serial -= 1
--> 119             out = self._parallizable.forward_backward(x)
    120             self._out_queue.put(out)
    121         else:

<ipython-input-47-9cc4d320d600> in forward_backward(self, x)
     18         with autograd.record():
---> 19             gt_label = label[:, :, 4:5]
     20             gt_box = label[:, :, :4]
     21             cls_pred, box_pred, roi, samples, matches, rpn_score, rpn_box, anchors = net(

/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/ndarray.py in __getitem__(self, key)
    509         indexing_dispatch_code = _get_indexing_dispatch_code(key)
    510         if indexing_dispatch_code == _NDARRAY_BASIC_INDEXING:
--> 511             return self._get_nd_basic_indexing(key)
    512         elif indexing_dispatch_code == _NDARRAY_ADVANCED_INDEXING:
    513             return self._get_nd_advanced_indexing(key)

/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/ndarray.py in _get_nd_basic_indexing(self, key)
    821                                  'index=%s of type=%s.' % (str(slice_i), str(type(slice_i))))
    822         kept_axes.extend(range(i+1, len(shape)))
--> 823         sliced_nd = op.slice(self, begin, end, step)
    824         if len(kept_axes) == len(shape):
    825             return sliced_nd

/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/register.py in slice(data, begin, end, step, out, name, **kwargs)

/usr/local/lib/python3.6/dist-packages/mxnet/_ctypes/ndarray.py in _imperative_invoke(handle, ndargs, keys, vals, out)
     90         c_str_array(keys),
     91         c_str_array([str(s) for s in vals]),
---> 92         ctypes.byref(out_stypes)))
     94     if original_output is not None:

/usr/local/lib/python3.6/dist-packages/mxnet/base.py in check_call(ret)
    251     """
    252     if ret != 0:
--> 253         raise MXNetError(py_str(_LIB.MXGetLastError()))

MXNetError: [10:50:45] src/operator/tensor/./matrix_op-inl.h:657: Check failed: param_begin.ndim() <= dshape.ndim() (3 vs. 2) : Slicing axis exceeds data dimensions
Stack trace:
  [bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4a357b) [0x7f5fce81657b]
  [bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x23fd387) [0x7f5fd0770387]
  [bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x2402f96) [0x7f5fd0775f96]
  [bt] (3) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(mxnet::imperative::SetShapeType(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, mxnet::DispatchMode*)+0x1fb1) [0x7f5fd0a6acc1]
  [bt] (4) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(mxnet::Imperative::Invoke(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&)+0x1db) [0x7f5fd0a7493b]
  [bt] (5) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x25ffd99) [0x7f5fd0972d99]
  [bt] (6) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x6f) [0x7f5fd097338f]
  [bt] (7) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7f601f94edae]
  [bt] (8) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x22f) [0x7f601f94e71f]

