如何阻止numpy meshgrid将默认数据类型设置为int64

时间:2017-02-07 13:56:16

标签: python numpy scipy

我必须使用numpy meshgrid创建一个非常大的网格。为了节省内存,我使用int8作为我尝试网格化的数组的dtype。但是,meshgrid不断将类型更改为使用大量内存的int64。这是一个简单的问题示例......

import numpy

grids = [numpy.arange(1, 4, dtype=numpy.int8), numpy.arange(1, 5, dtype=numpy.int8)]

print grids
print grids[0].dtype, grids[0].nbytes

x1, y1 = numpy.meshgrid(*grids)

print x1.dtype, x1.nbytes

此脚本打印

[array([1, 2, 3], dtype=int8), array([1, 2, 3, 4], dtype=int8)]

int8 3

int64 96

为什么meshgrid会这样做?有什么方法可以阻止它吗?我需要创建一个巨大的数组,所以我不能使用meshgrid,除非我可以控制输出的数据类型。这是预期的行为还是一个numpy bug?我在numpy中使用的所有其他函数都保留了数据类型或允许您使用dtype参数更改它。 meshgrid函数似乎不允许这样做。

1 个答案:

答案 0 :(得分:4)

您可以将numpy.meshgrid()的可选copy参数设置为False(但请注意,它有一些限制条件):

  

meshgrid(*xi, **kwargs)

     

...

     

copybool,可选

     

如果False,则返回原始数组的视图   保存记忆。默认值为True。请注意sparse=False,   copy=False可能会返回非连续数组。此外,   广播阵列的多个元素可以指单个元素   记忆位置。如果需要写入数组,请进行复制   第一

证明它有效:

>>> import numpy
>>> 
>>> grids = [numpy.arange(1, 4, dtype=numpy.int8), numpy.arange(1, 5, dtype=numpy.int8)]
>>> 
>>> print grids
[array([1, 2, 3], dtype=int8), array([1, 2, 3, 4], dtype=int8)]
>>> print grids[0].dtype, grids[0].nbytes
int8 3
>>>
>>> x1, y1 = numpy.meshgrid(*grids, copy=False)
>>>                        #        ^^^^^^^^^^
>>> print x1.dtype, x1.nbytes
int8 12