I'm trying to use AdaGrad for binary logistic regression in MXNet, but I get an error in the step call (trainer.step). My dataset is already in float64.
My code:
import numpy as np
Error:
MXNetError Traceback (most recent call last)
<ipython-input-125-8c2b3ff57944> in <module>()
17 # print(net.weight.data()[0])
18 loss.backward()
---> 19 trainer.step(batch_size)
20 cumulative_loss += nd.sum(loss).asscalar()
21 # print("Epoch %s, loss: %s" % (e, cumulative_loss/1087))
/usr/local/lib/python3.6/dist-packages/mxnet/base.py in check_call(ret)
251 """
252 if ret != 0:
--> 253 raise MXNetError(py_str(_LIB.MXGetLastError()))
254
255
MXNetError: [11:00:52] src/operator/contrib/../elemwise_op_common.h:135: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node at 2-th input: expected float64, got float32
Stack trace:
[bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4b04cb) [0x7fb6783354cb]
[bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x556b53) [0x7fb6783dbb53]
[bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7826d0) [0x7fb6786076d0]
[bt] (3) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(mxnet::imperative::SetShapeType(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, mxnet::DispatchMode*)+0xf68) [0x7fb67a4e75e8]
[bt] (4) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(mxnet::Imperative::Invoke(mxnet::Context const&, nnvm::NodeAttrs const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&, std::vector<mxnet::NDArray*, std::allocator<mxnet::NDArray*> > const&)+0x1db) [0x7fb67a4f1a0b]
[bt] (5) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x2565409) [0x7fb67a3ea409]
[bt] (6) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(MXImperativeInvokeEx+0x6f) [0x7fb67a3ea9ff]
[bt] (7) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call_unix64+0x4c) [0x7fb6c1a4cdae]
[bt] (8) /usr/lib/x86_64-linux-gnu/libffi.so.6(ffi_call+0x22f) [0x7fb6c1a4c71f]
Can someone tell me how to fix this error?
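For reference, a sketch of the dtype consistency the failing check enforces. The message says the 2-th input of an operator run inside trainer.step was float32 while float64 was expected. One plausible reading (an assumption, not confirmed from the code above) is that the parameters or the AdaGrad state were created with Gluon's default float32 while the data and gradients are float64. The minimal NumPy sketch below is not MXNet code; the helper adagrad_logreg, its hyperparameters and the choice of float32 are hypothetical. It casts the data, the weights and the AdaGrad accumulator to a single dtype up front, so no update ever mixes float32 and float64.

```python
import numpy as np


def sigmoid(z):
    # Logistic function; NumPy keeps the input array's float dtype
    return 1.0 / (1.0 + np.exp(-z))


def adagrad_logreg(X, y, lr=0.1, eps=1e-7, epochs=100, dtype=np.float32):
    """Hypothetical helper: full-batch AdaGrad for binary logistic regression.

    Data, weights, gradients and the AdaGrad history all share one dtype,
    so no operation mixes float32 and float64.
    """
    X = np.asarray(X, dtype=dtype)
    y = np.asarray(y, dtype=dtype)
    n = X.shape[0]
    # Append a column of ones so the last weight acts as the bias
    Xb = np.hstack([X, np.ones((n, 1), dtype=dtype)])
    w = np.zeros(Xb.shape[1], dtype=dtype)
    # AdaGrad state: running sum of squared gradients, same dtype as w
    history = np.zeros_like(w)
    for _ in range(epochs):
        p = sigmoid(Xb @ w)
        grad = Xb.T @ (p - y) / n  # gradient of the mean log-loss
        history += grad ** 2
        # Per-coordinate step scaled by the accumulated gradient magnitude
        w -= lr * grad / (np.sqrt(history) + eps)
    return w


# Usage: float64 input, as in the question, is cast to float32 inside
X = np.array([[-2.0], [-1.0], [1.0], [2.0]])
y = np.array([0.0, 0.0, 1.0, 1.0])
w = adagrad_logreg(X, y)
print(w.dtype)  # float32
```

The MXNet analogue would be to cast the dataset once (for example with .astype('float32')) before building batches, or alternatively to cast the network itself to float64; which of the two this MXNet version's AdaGrad handles cleanly is an assumption worth checking.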