我正在使用自己的Pascal VOC数据集创建自己的SSD模型。如何解决MultiBoxTarget中的错误。
我过去使用内置数据集(pikachu)检查了该程序,同时处理了自己的数据集,这导致了错误
在这里我调用函数( training_target )
import time
from mxnet import autograd as ag
for epoch in range(start_epoch, epochs):
# reset iterator and tick
cls_metric.reset()
box_metric.reset()
tic = time.time()
# iterate through all batch
for i, batch in enumerate(train_data):
btic = time.time()
# record gradients
with ag.record():
x = batch[0].as_in_context(ctx)
y = batch[1].as_in_context(ctx)
default_anchors, class_predictions, box_predictions = net(x)
box_target, box_mask, cls_target = training_targets(default_anchors, class_predictions, y)
# losses
loss1 = cls_loss(class_predictions, cls_target)
loss2 = box_loss(box_predictions, box_target, box_mask)
# sum all losses
loss = loss1 + loss2
# backpropagate
loss.backward()
# apply
trainer.step(batch_size)
# update metrics
cls_metric.update([cls_target], [nd.transpose(class_predictions, (0, 2, 1))])
box_metric.update([box_target], [box_predictions * box_mask])
if (i + 1) % log_interval == 0:
name1, val1 = cls_metric.get()
name2, val2 = box_metric.get()
print('[Epoch %d Batch %d] speed: %f samples/s, training: %s=%f, %s=%f'
%(epoch ,i, batch_size/(time.time()-btic), name1, val1, name2, val2))
# end of epoch logging
name1, val1 = cls_metric.get()
name2, val2 = box_metric.get()
print('[Epoch %d] training: %s=%f, %s=%f'%(epoch, name1, val1, name2, val2))
print('[Epoch %d] time cost: %f'%(epoch, time.time()-tic))
功能training_targets定义了在 MultiBoxTarget 中出现错误的位置:
from mxnet.contrib.ndarray import MultiBoxTarget
def training_targets(default_anchors, class_predicts, labels):
class_predicts = nd.transpose(class_predicts, axes=(0, 2, 1))
z = MultiBoxTarget(*[default_anchors, labels, class_predicts])
box_target = z[0] # box offset target for (x, y, width, height)
box_mask = z[1] # mask is used to ignore box offsets we don't want to penalize, e.g. negative samples
cls_target = z[2] # cls_target is an array of labels for all anchors boxes
return box_target, box_mask, cls_target
预期输出:训练模型并保存net.save_parameters('ssd_%d.params'%历元)
实际输出:
MXNetError Traceback (most recent call last)
<ipython-input-80-6e8fe42e4df5> in <module>()
16 default_anchors, class_predictions, box_predictions = net(x)
17 print(y.shape)
---> 18 box_target, box_mask, cls_target = training_targets(default_anchors, class_predictions, y)
19 # losses
20 loss1 = cls_loss(class_predictions, cls_target)
<ipython-input-68-866caabcf8c9> in training_targets(default_anchors, class_predicts, labels)
2 def training_targets(default_anchors, class_predicts, labels):
3 class_predicts = nd.transpose(class_predicts, axes=(0, 2,
1))
----> 4 z = MultiBoxTarget(*[default_anchors, labels, class_predicts])
5 box_target = z[0] # box offset target for (x, y, width, height)
6 box_mask = z[1] # mask is used to ignore box offsets we don't want to penalize, e.g. negative samples
/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/register.py in MultiBoxTarget(anchor, label, cls_pred, overlap_threshold, ignore_label, negative_mining_ratio, negative_mining_thresh, minimum_negative_samples, variances, out, name, **kwargs)
/usr/local/lib/python3.6/dist-packages/mxnet/_ctypes/ndarray.py in _imperative_invoke(handle, ndargs, keys, vals, out)
90 c_str_array(keys),
91 c_str_array([str(s) for s in vals]),
---> 92 ctypes.byref(out_stypes)))
93
94 if original_output is not None:
/usr/local/lib/python3.6/dist-packages/mxnet/base.py in check_call(ret)
250 """
251 if ret != 0:
--> 252 raise MXNetError(py_str(_LIB.MXGetLastError()))
253
254
**MXNetError: [08:50:36] include/mxnet/operator.h:228: Check failed: in_type->at(i) == mshadow::default_type_flag || in_type->at(i) == -1 Unsupported data type 1**
Stack trace returned 10 entries:
[bt] (0) /usr/local/lib/python3.6/dist-
packages/mxnet/libmxnet.so(+0x23d55a) [0x7f454cf5555a]
[bt] (1) /usr/local/lib/python3.6/dist-
packages/mxnet/libmxnet.so(+0x23dbc1) [0x7f454cf55bc1]
[bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x2fb9dd) [0x7f454d0139dd]
[bt] (3) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x2e0c0b5) [0x7f454fb240b5]
[bt] (4) /usr/local/lib/python3.6/dist- packages/mxnet/libmxnet.so(mxnet::imperative::SetShapeType(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, mxnet::DispatchMode*)+0x1274) [0x7f454f936814]
[bt] (5) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(mxnet::Imperative::Invoke(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&)+0x309) [0x7f454f9400b9]
[bt] (6) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x2b2d8b9) [0x7f454f8458b9]
[bt] (7) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x6f) [0x7f454f845eaf]
[bt] (8) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7f45761e6dae]
[bt] (9) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x22f) [0x7f45761e671f]