Replacing integers with tuples in a numpy array?

Time: 2017-04-04 20:11:28

Tags: python arrays numpy

I have a numpy array:

a = [[0 1 2 3 4]
     [0 1 2 3 4]
     [0 1 2 3 4]]

I have a dictionary containing the values I want to substitute/map:

d = { 0 : (   0,   1 ),
      1 : ( 100, 101 ),
      2 : ( 200, 201 ),
      3 : ( 300, 301 ),
      4 : ( 400, 401 )}

So that I end up with:

a = [[(0, 1) (100, 101) (200, 201) (300, 301) (400, 401)]
     [(0, 1) (100, 101) (200, 201) (300, 301) (400, 401)]
     [(0, 1) (100, 101) (200, 201) (300, 301) (400, 401)]]

According to this SO answer, one way to map values based on a dictionary is:

b = np.copy( a )
for k, v in d.items(): b[ a == k ] = v

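This works when the keys and values have the same data type. In my case, however, the key is an int and the new value is a tuple (of ints), so I get a cannot assign 2 input values error.

Instead of b = np.copy( a ), I tried: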
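To make the failure concrete, here is a minimal sketch that reproduces it, assuming a is the 3x5 array shown above (rebuilt here with np.tile). The mask a == 0 selects three cells while the tuple supplies only two values, so NumPy cannot pair them up. The exact wording of the message may differ between NumPy versions.

```python
import numpy as np

# Assumption: the same 3x5 array as in the question.
a = np.tile(np.arange(5), (3, 1))
b = np.copy(a)

mask = a == 0  # three cells are True, one per row

try:
    # A 2-tuple is treated as two values to spread over three cells.
    b[mask] = (0, 1)
    error_message = None
except ValueError as exc:
    error_message = str(exc)

print(int(mask.sum()), error_message)
```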

b = a.astype( ( np.int, 2 ) )

However, I get the (reasonable) error ValueError: could not broadcast input array from shape (3,5) into shape (3,5,2).

So how can I map ints to tuples in a numpy array?

2 Answers:

Answer 0: (score: 1)

How about this?

import numpy as np

data = np.tile(np.arange(5), (3, 1))

lookup = { 0 : ( 0, 1 ),
           1 : ( 100, 101 ),
           2 : ( 200, 201 ),
           3 : ( 300, 301 ),
           4 : ( 400, 401 )}

# get keys and values, make sure they are ordered the same
keys, values = zip(*lookup.items())

# making use of the fact that the keys are non negative ints
# create a numpy friendly lookup table
out = np.empty((max(keys) + 1,), object)
out[list(keys)] = values

# now out can be used to look up the tuples using only numpy indexing
result = out[data]
print(result)

Prints:

[[(0, 1) (100, 101) (200, 201) (300, 301) (400, 401)]
 [(0, 1) (100, 101) (200, 201) (300, 301) (400, 401)]
 [(0, 1) (100, 101) (200, 201) (300, 301) (400, 401)]]
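The lookup table above relies on the keys being small non-negative integers. If the keys are sparse or negative, one possible workaround (not part of the original answer) is to let np.unique compress the values to positions first. The sketch below uses made-up keys and data for illustration only, and assumes every value in the array has an entry in the dictionary.

```python
import numpy as np

# Hypothetical data with sparse and negative keys.
data = np.array([[-5, 1000, 7],
                 [7, -5, 1000]])
lookup = {-5: (0, 1), 7: (100, 101), 1000: (200, 201)}

# uniq holds the sorted distinct values; inv maps each cell to its
# position in uniq.
uniq, inv = np.unique(data, return_inverse=True)

# Fill an object table one element at a time so each tuple is stored whole.
table = np.empty(len(uniq), dtype=object)
for i, key in enumerate(uniq):
    table[i] = lookup[int(key)]

# The reshape keeps this independent of the shape np.unique gives inv.
result = table[inv].reshape(data.shape)
print(result)
```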

Alternatively, it may be worth considering an integer array instead:

out = np.empty((max(keys) + 1, 2), int)
out[list(keys), :] = values

result = out[data, :]
print(result)

Prints:

[[[  0   1]
  [100 101]
  [200 201]
  [300 301]
  [400 401]]

 [[  0   1]
  [100 101]
  [200 201]
  [300 301]
  [400 401]]

 [[  0   1]
  [100 101]
  [200 201]
  [300 301]
  [400 401]]]
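With the integer version, the pair lives on a trailing axis of length 2, so each component can be sliced out and used in ordinary vectorized arithmetic. A minimal, self-contained sketch of that usage follows. The names table, first and second are illustrative, and the table is filled with a plain loop for clarity.

```python
import numpy as np

data = np.tile(np.arange(5), (3, 1))
lookup = {0: (0, 1), 1: (100, 101), 2: (200, 201),
          3: (300, 301), 4: (400, 401)}

# One row per key, one column per tuple component.
table = np.empty((max(lookup) + 1, 2), dtype=int)
for key, pair in lookup.items():
    table[key] = pair

result = table[data]      # shape (3, 5, 2)
first = result[..., 0]    # first component of every pair, shape (3, 5)
second = result[..., 1]   # second component of every pair, shape (3, 5)

print(result.shape)
print(second - first)
```

Because the result is a plain integer array, operations such as second - first stay in compiled NumPy code, which is generally not the case for an object array of tuples.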

Answer 1: (score: 0)

You could use a structured array (it behaves much like using tuples, but you don't lose the speed advantage):

>>> rgb_dtype = np.dtype([('r', np.int64), ('g', np.int64)])
>>> arr = np.zeros(a.shape, dtype=rgb_dtype)
>>> for k, v in d.items():
...     arr[a==k] = v
>>> arr
array([[(  0,   1), (100, 101), (200, 201), (300, 301), (400, 401)],
       [(  0,   1), (100, 101), (200, 201), (300, 301), (400, 401)],
       [(  0,   1), (100, 101), (200, 201), (300, 301), (400, 401)]], 
      dtype=[('r', '<i8'), ('g', '<i8')])

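The for loop could probably be replaced by a faster operation. However, if your a contains very few distinct values compared to its total size, this should be fast enough.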
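One way the masked loop might be avoided is to combine this with the lookup-table idea from the other answer: fill a small structured table once, then index it with a. This is only a sketch, under the assumption that the keys are non-negative integers small enough to serve as table indices. The names pair_dtype and table are illustrative.

```python
import numpy as np

a = np.tile(np.arange(5), (3, 1))
d = {0: (0, 1), 1: (100, 101), 2: (200, 201),
     3: (300, 301), 4: (400, 401)}

pair_dtype = np.dtype([('r', np.int64), ('g', np.int64)])

# One structured record per key; the loop runs once per key and
# never scans the full array.
table = np.zeros(max(d) + 1, dtype=pair_dtype)
for k, v in d.items():
    table[k] = v

# A single fancy-indexing step maps every cell of a to its record.
arr = table[a]
print(arr)
```

The loop now touches only the table, so its cost depends on the number of keys rather than on the size of a.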