相当于" case"对于np.where

时间:2016-11-30 02:18:34

标签: python numpy if-statement where

np.where允许您选择要为布尔类型查询分配的值,例如

test = [0,1,2]
np.where(test==0,'True','False')
print test
['True','False','False']

这基本上是一个' if'声明。有没有一种pythonic的方式,如果,如果,否则'一个numpy数组的声明(有不同的情况)?

这是我的解决方法:

color = [0,1,2]
color = np.where(color==0,'red',color)
color = np.where(color==1,'blue',color)
color = np.where(color==2,'green',color)
print color
['red','blue','green']

但我想知道是否有更好的方法可以做到这一点。

2 个答案:

答案 0 :(得分:1)

np.choose属于多元素where

In [97]: np.choose([0,1,1,2,0,1],['red','green','blue'])
Out[97]: 
array(['red', 'green', 'green', 'blue', 'red', 'green'], 
      dtype='<U5')
In [113]: np.choose([0,1,2],[0,np.array([1,2,3])[:,None], np.arange(10,13)])
Out[113]: 
array([[ 0,  1, 12],
       [ 0,  2, 12],
       [ 0,  3, 12]])

在更复杂的情况下,有助于妥善处理广播。

有限制,例如不超过32种选择。它的使用率几乎不到np.where

有时您只想多次应用where或布尔屏蔽:

In [115]: x
Out[115]: 
array([[ 0,  1,  2,  3],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
In [116]: x[x<4] += 10
In [117]: x
Out[117]: 
array([[10, 11, 12, 13],
       [ 4,  5,  6,  7],
       [ 8,  9, 10, 11]])
In [118]: x[x>8] -=3
In [119]: x
Out[119]: 
array([[ 7,  8,  9, 10],
       [ 4,  5,  6,  7],
       [ 8,  6,  7,  8]])
In [120]: x[(4<x)&(x<8)] *=2
In [121]: x
Out[121]: 
array([[14,  8,  9, 10],
       [ 4, 10, 12, 14],
       [ 8, 12, 14,  8]])

答案 1 :(得分:0)

更多Pythonic方法之一就是使用列表理解,如下所示:

>>> color = [0,1,2]
>>> ['red' if c == 0 else 'blue' if c == 1 else 'green' for c in color]
['red', 'blue', 'green']

如果您阅读它,它会非常直观。对于列表color中的给定项目,如果颜色为'red',则新列表中的值为0,如果'blue'1,则为'green'除此之外,还有if。但是,我不知道我是否会将列表理解中的else for放在三个以上。那里有一个>>> color_dict = {0: 'red', 1: 'blue', 2: 'green'} >>> [color_dict[number] for number in color] ['red', 'blue', 'green'] 循环。

或者您可以使用字典,这可能更多&#34; Pythonic,&#34;并且会更具可扩展性:

library(plot3D); library(plot3Drgl)
with(quakes, 
  scatter3D(x=long, y=lat, z=-depth, colvar=mag, pch=16, cex=1.5, 
  xlab="longitude", ylab="latitude", zlab="depth, km", 
  clab=c("Richter", "Magnitude"), main="Earthquakes off Fiji", 
  ticktype="detailed", theta=10, d=2, 
  colkey=list(length=0.5, width=0.5, cex.clab=1))
)
 plotrgl(lighting = TRUE, smooth = TRUE, cex=2)

with(quakes, 
  scatter3D(x=long, y=lat, z=-depth, colvar=mag, pch=16, cex=1.5, 
  xlab="longitude", ylab="latitude", zlab="depth, km", 
  clab=c("Richter", "Magnitude"), main="Earthquakes off Fiji", 
  ticktype="detailed", theta=10, d=2, 
  colkey=list(length=0.5, width=0.5, cex.clab=2))
)
plotrgl(lighting = TRUE, smooth = TRUE, cex=2)