Mxnet数据类型为float64,但一直说它是float32

时间:2020-01-13 08:02:01

标签: python mxnet

我是pytorch和tensorflow用户。为了使用AWS sagemaker的弹性推断,我遇到了Mxnet。

Mxnet gluon数据集api与pytorch的数据集非常相似。

class CustomDataset(mxnet.gluon.data.Dataset):
    def __init__(self):
        self.train_df = pd.read_csv('/shared/KTUTOR/test_summary_data.csv')
    def __getitem__(self, idx):
        return mxnet.nd.array(self.train_df.loc[idx, ['TT', 'TF', 'FT', 'FF']], dtype='float64'), mxnet.nd.array(self.train_df.loc[idx, ['p1']], dtype='float64')
    def __len__(self):
        return len(self.train_df)

我如上所述定义了我的customdataset,并将数据类型设置为float64。

test_data = mxnet.gluon.data.DataLoader(CustomDataset(), batch_size=8, shuffle=True, num_workers=2)

我用DataLoader包装了数据集,到目前为止没有错误。 当我将数据传递到网络时,错误会增加。

for epoch in range(1):
for data, label in test_data:
    print(data.dtype)
    print(label.dtype)
    with autograd.record():
        output = net(data)
        loss = softmax_cross_entropy(output, label)
    loss.backward()
    trainer.step(batch_size)

net(data)中的错误上升,并且错误消息如下所示。

MXNetError: [07:53:55] src/operator/contrib/../elemwise_op_common.h:135: Check failed: assign(&dattr, vec.at(i)): Incompatible attr in node  at 1-th input: expected float64, got float32
Stack trace:
  [bt] (0) /root/anaconda3/lib/python3.7/site-packages/mxnet/libmxnet.so(+0x4b09db) 
[0x7f00f96519db] ...

当我打印数据和标签的类型时,它们都是float64,但是MXNet告诉我数据的数据类型是float32。有人可以解释为什么会这样吗? 提前谢谢。

2 个答案:

答案 0 :(得分:1)

您的网络位于float64还是float32?尝试将权重转换为float64:

net = net.cast('float64')

话虽这么说,以我的经验,在float64中训练DL模型并不常见,float32和float16在训练中更为常见。 MXNet使您可以轻松地使用float16精度来训练explicitly或通过AMP tool (Automatic Mixed Precision)

自动进行训练

答案 1 :(得分:1)

您应该直观地将输入数据转换为float32(而不是float64)。

尽管错误似乎表明与该建议完全相反,但此失败的检查是从网络中的低级操作传播而来的,该操作很可能采用以下形式:(input * weight) + bias

由于input是计算的第一个变量,因此它将其他变量(权重和偏差)的期望数据类型设置为float64。因此,检查实际上在抱怨weight的数据类型是float32,而预期是float64。