numpy数组中的条件索引

时间:2017-10-08 14:06:57

标签: python numpy

我一直在阅读在线教程

from sklearn.decomposition import  * 
from sklearn import datasets
import matplotlib.pyplot as plt
import time

digits=datasets.load_digits()

randomized_pca = PCA(n_components=2,svd_solver='randomized')

# a numpy array with shape= (1800,2)  
reduced_data_rpca = randomized_pca.fit_transform(digits.data)

# make a scatter plot

colors = ['black', 'blue', 'purple', 'yellow', 'pink', 'red', 'lime', 'cyan', 
'orange', 'gray']

start=time.time()

#   Time Taken for this loop = 9.5 seconds

# for i in range(len(reduced_data_rpca)):
#         x = reduced_data_rpca[i][0]
#         y = reduced_data_rpca[i][1]
#         plt.scatter(x,y,c=colors[digits.target[i]])

# Alternative way  TimeTaken = 0.2 sec

# plots all the points (x,y) with color[i] in ith iteration

for i in range(len(colors)):
    """assigns all the elements (accordingly to x and y)  whose label(0-9) equals the variable i (am I 
    correct ? does this mean it iterates the whole again to check for the 
    equality?) """
    x = reduced_data_rpca[:, 0][digits.target == i]  
    y = reduced_data_rpca[:, 1][digits.target == i]
    plt.scatter(x, y, c=colors[i])

end=time.time()

print("Time taken",end-start," Secs")

我的问题是虽然注释和非注释循环都执行相同的操作但我无法理解第二个循环是如何工作的以及为什么它的性能优于另一个循环。

1 个答案:

答案 0 :(得分:1)

您的第一个循环(已注释掉)循环遍历1800个元素的数组。第二个使用numpy的索引方法用于"内部循环"并且只需通过10种颜色的常规for循环。 Numpy数组比常规列表和循环更快。

但是digits.target == i做了什么?在我看来,它不是从reduced_data_rpca中挑选出一个布尔数组,而是一遍又一遍地对字典和数组索引进行比较。这个比较的结果总是False

另见:https://docs.scipy.org/doc/numpy-1.13.0/user/basics.indexing.html