使用numpy> = 1.13查看特定字段为ndarray

时间:2017-11-29 10:31:54

标签: python numpy scipy

数据位于结构化数组中:

import numpy as np
dtype = [(field, float) for field in ['x', 'y', 'z', 'prop1', 'prop2']]
data = np.array([(1,2,3,4,5), (6,7,8,9,10), (11,12,13,14,15)], dtype=dtype)

对于某些操作,位置作为单个nx3阵列访问,例如:

positions = data[['x', 'y', 'z']].view(dtype=float).reshape(-1, 3)
ranges = np.sqrt(np.sum(positions**2, 1))

从numpy 1.12开始,会发出以下警告:

  

FutureWarning:Numpy检测到您可能正在查看或写入通过选择a中的多个字段返回的数组   结构化数组。

     

此代码可能会在numpy 1.13中中断,因为这将返回视图而不是副本 - 有关详细信息,请参阅发行说明。

Here是发行说明中的​​相应条目:

  

使用多个字段(例如arr[['f1', 'f3']])索引结构化数组会将视图返回到1.13中的原始数组,而不是副本。请注意,与1.12中的副本不同,返回的视图将具有与原始数组中的中间字段对应的额外填充字节,这将影响arr[['f1', 'f3']].view(newdtype)等代码。

如何将此代码移植到numpy> = 1.13?

2 个答案:

答案 0 :(得分:2)

检查numpy 1.13已公布的变更似乎尚未发生。因此,让我们模拟未来:

未来的行为可能不会复制数据,而是创建一个只包含您想要的字段的dtype,而不是原始dtype的itemsize。因此,每个元素,未使用的内存部分都会有间隙。

xyz_tp = xyz_tp = np.dtype({'names': list('xyz'),
                            'formats': tuple(data.dtype.fields[f][0] for f in 'xyz'),
                            'offsets': tuple(data.dtype.fields[f][1] for f in 'xyz'), 
                            'itemsize': data.dtype.itemsize})

xyz = data.view(xyz_tp)
xyz
# array([(  1.,   2.,   3.), (  6.,   7.,   8.), ( 11.,  12.,  13.)],
#       dtype={'names':['x','y','z'], 'formats':['<f8','<f8','<f8'], 'offsets':[0,8,16], 'itemsize':40})

未使用的内存位置及其内容会被忽略但仍然存在,因此如果您使用内置dtype进行查看,则会再次出现。

xyz.view(float)
# array([  1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,  11.,
#         12.,  13.,  14.,  15.])
# Ouch!

一般修复方法是使用相同的字段强制转换为连续(无间隙)dtype。这将强制复制

xyz_cont_tp = np.dtype({'names': list('xyz'), 'formats': 3*('<f8',)})
xyz.astype(xyz_cont_tp).view(float).reshape(-1, 3)
# array([[  1.,   2.,   3.],
#        [  6.,   7.,   8.],
#        [ 11.,  12.,  13.]])

在所选字段连续且类型相同的特殊情况下,您还可以执行以下操作:

np.lib.stride_tricks.as_strided(data.view(float), shape=(3,3), strides=data.strides + (8,))
# array([[  1.,   2.,   3.],
#        [  6.,   7.,   8.],
#        [ 11.,  12.,  13.]])

此方法不会复制数据,但会创建真实视图。

答案 1 :(得分:0)

其他几种相邻浮点字段的方式。这里对于从'x'开始的3个字段,我们得到相同的结果:

np.ndarray((len(data),3), float, data, offset= data.dtype.fields['x'][1], strides= (data.strides[0], np.dtype(float).itemsize))