MXNet: backward shape inconsistent with the forward shape

Time: 2017-08-23 16:50:48

Tags: python mxnet

After upgrading to MXNet 0.11.0, my old code started failing with a strange error.

The code uses the Boston housing dataset from scikit-learn:

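import mxnet as mx
import numpy as np

# l1_reg, x_train, y_train, f and mon are defined elsewhere (not shown in the post).
data = mx.sym.Variable("data")
y = mx.sym.Variable("output_label")
regularization_cost = 0  # assumed starting value; the initialization is not shown in the post

# Hidden layer: 20 units with ReLU, plus an L1 penalty on its weights
fc = mx.sym.FullyConnected(data=data, num_hidden=20, name='FC1')
fc = mx.sym.Activation(data=fc, act_type='relu', name='act1')
regularization_cost = regularization_cost + mx.sym.sum(mx.sym.abs(fc.get_internals()['FC1_weight']))

# Output layer: a single regression unit, plus an L1 penalty on its weights
fc = mx.sym.FullyConnected(data=fc, num_hidden=1, name='FC2')
regularization_cost = regularization_cost + mx.sym.sum(mx.sym.abs(fc.get_internals()['FC2_weight']))

# Total cost: L1 penalty plus half the mean squared error
ce = l1_reg * regularization_cost + 0.5 * mx.sym.mean(mx.symbol.square(fc - y))

train_iter = mx.io.NDArrayIter(data=x_train[:-44], label=y_train[:-44][:, np.newaxis], batch_size=20, shuffle=False,
                               label_name='output_label', last_batch_handle='pad')
# NOTE: `loss` is not defined in the snippet as posted; it is presumably built from `ce`.
mod = mx.mod.Module(symbol=loss,
                    context=mx.cpu(),
                    data_names=['data'],
                    label_names=['output_label'])
mod.fit(train_iter, num_epoch=10, batch_end_callback=f, monitor=mon)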

Binding the module raises the following error:

---------------------------------------------------------------------------
MXNetError                                Traceback (most recent call last)
~/dev/mxnet/python/mxnet/symbol.py in simple_bind(self, ctx, grad_req, type_dict, group2ctx, shared_arg_names, shared_exec, shared_buffer, **kwargs)
   1472                                                  shared_exec_handle,
-> 1473                                                  ctypes.byref(exe_handle)))
   1474         except MXNetError as e:

~/dev/mxnet/python/mxnet/base.py in check_call(ret)
    128     if ret != 0:
--> 129         raise MXNetError(py_str(_LIB.MXGetLastError()))
    130 

MXNetError: [12:45:50] src/pass/infer_shape_type.cc:112: Check failed: rshape[eid] == rshape[idx.entry_id(fnode.inputs[i])] ((1,20) vs. ()) Backward shape inconsistent with the forward shape

Stack trace returned 10 entries:
RuntimeError: simple_bind error. Arguments:
output_label: (20, 1)
data: (20, 13)
[12:45:50] src/pass/infer_shape_type.cc:112: Check failed: rshape[eid] == rshape[idx.entry_id(fnode.inputs[i])] ((1,20) vs. ()) Backward shape inconsistent with the forward shape
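
To make the shapes in the message easier to read, here is a minimal NumPy sketch of the same computation. It is not MXNet code and does not reproduce the failure. It only shows which arrays have which shapes, assuming biases are omitted and using a hypothetical `l1_reg` value and random data in place of the real dataset. The `(1, 20)` in the error matches the shape of `FC2_weight`, since `FullyConnected` stores weights as `(num_hidden, input_dim)`, and the gradient of an L1 penalty with respect to a weight has the same shape as that weight.

```python
import numpy as np

rng = np.random.default_rng(0)

# Batch size and feature count are taken from the error message:
# data: (20, 13), output_label: (20, 1)
batch, n_features, n_hidden = 20, 13, 20

x = rng.normal(size=(batch, n_features))   # stand-in for the real features
y = rng.normal(size=(batch, 1))            # stand-in for the real labels

# FullyConnected stores weights as (num_hidden, input_dim)
w1 = rng.normal(size=(n_hidden, n_features))   # FC1_weight -> (20, 13)
w2 = rng.normal(size=(1, n_hidden))            # FC2_weight -> (1, 20)

# Forward pass (biases omitted for brevity)
hidden = np.maximum(x @ w1.T, 0.0)   # ReLU(FC1) -> (20, 20)
pred = hidden @ w2.T                 # FC2 -> (20, 1)

l1_reg = 0.01  # hypothetical value; the post does not show it

# Same cost as in the question: L1 penalty plus half the mean squared error
l1_penalty = np.abs(w1).sum() + np.abs(w2).sum()
cost = l1_reg * l1_penalty + 0.5 * np.mean((pred - y) ** 2)

# Gradient of the L1 term with respect to FC2_weight: same shape as the weight
grad_l1_w2 = l1_reg * np.sign(w2)

print(w2.shape, grad_l1_w2.shape, pred.shape)
```

The forward and backward shapes for `FC2_weight` agree here, which is the consistency that the failing check in `infer_shape_type.cc` verifies.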

1 answer:

Answer 0 (score: 1)