在输入和输出整数数组的神经网络中,我应该为PyTorch参数使用哪种dtype?

时间:2018-08-30 17:40:04

标签: integer precision numeric pytorch continuous

我目前正在PyTorch中构建一个神经网络,该神经网络接受 integers 的张量并输出 integers 的张量。输入和输出张量的元素中只有少量“允许”的正整数(例如0、1、2、3和4)。

神经网络通常在连续空间中工作。 例如,层之间的非线性激活函数是连续的,并将整数映射到实数(包括非整数)。

是否最好内部使用torch.uint8之类的无符号整数来表示网络的权重和偏差以及一些将int映射到int的自定义激活函数?

还是我应该使用像torch.float32这样的高精度浮点数,然后通过将实数归类为最接近的整数最后舍入?我认为这是第二种策略,但是也许我错过了一些效果很好的方法。

1 个答案:

答案 0 :(得分:1)

在不了解您的应用程序太多的情况下,我会四舍五入地选择torch.float32。主要原因是,如果您使用GPU计算神经网络,则将需要权重和数据采用float32数据类型。如果您不打算训练神经网络而想在CPU上运行,那么torch.uint8之类的数据类型可能会为您提供帮助,因为您可以在每个时间间隔获得更多的指令(​​即您的应用程序应该运行得更快)。如果那不为您提供任何线索,那么请更详细地说明您的应用程序。