Numpy数组按两个条件过滤

时间:2013-11-07 20:42:08

标签: python numpy cluster-analysis k-means

我正在尝试运行自定义kmeans聚类算法,并且无法按群集获取2-d numpy数组的每列(term)的文档频率。我当前的算法有两个numpy数组,一个原始数据集按术语[2000L,9500L]列出文档,一个是聚类赋值[2000L,]。共有5个集群。我需要做的是创建一个列出每个集群的文档频率的数组 - 基本上是列数与不同数组中的行号匹配的每列中的计数。输出将是[5L,9500L]阵列(簇x项)。我很难找到一种方法来做相当于一个countif和group by。下面是一些示例数据和我想要的输出,如果我只使用2个集群运行它:

import numpy as np

dataset = np.array[[1,2,0,3,0],[0,2,0,0,3],[4,5,2,3,0],[0,0,2,3,0]]
clusters = np.array[0,1,1,0]
#run code here to get documentFrequency
print documentFrequency
>> [1,1,1,2,0],[1,2,1,1,1]

我的想法是选择与每个群集匹配的特定行,因为这样计算应该很容易。例如,如果我可以将数据拆分为以下数组:

cluster0 = np.array[[1,2,0,3,0],[0,0,2,3,0]]
cluster1 = np.array[[0,2,0,0,3],[4,5,2,3,0]]

任何方向或指针都会非常感激!

3 个答案:

答案 0 :(得分:4)

我认为没有任何简单的方法可以对您的代码进行矢量化,但如果您只有几个群集,则可以做到显而易见:

>>> cluster_count = np.max(clusters)+1
>>> doc_freq = np.zeros((cluster_count, dataset.shape[1]), dtype=dataset.dtype)
>>> for j in xrange(cluster_count):
...     doc_freq[j] = np.sum(dataset[clusters == j], axis=0)
... 
>>> doc_freq
array([[1, 2, 2, 6, 0],
       [4, 7, 2, 3, 3]])

答案 1 :(得分:1)

正如@Jaime所说,如果你只有几个簇,那么使用在最小轴长度上手动循环的常用技巧是有意义的。通常情况下,这可以让你获得完全矢量化的大部分好处,而且很少会让你变得聪明起来。

也就是说,当您发现自己想要groupby时,您经常会在某个域中使用pandas之类的高级工具,非常方便:

>>> pd.DataFrame(dataset).groupby(clusters).sum()
   0  1  2  3  4
0  1  2  2  6  0
1  4  7  2  3  3

如果需要,您可以轻松回归ndarray

>>> pd.DataFrame(dataset).groupby(clusters).sum().values
array([[1, 2, 2, 6, 0],
       [4, 7, 2, 3, 3]])

答案 2 :(得分:0)

根据BLAS编写的编译程序,将此作为矩阵乘法可以更快:

cvals = (clusters == np.arange(clusters.max()+1)[:,None]).astype(int)

cvals
array([[1, 0, 0, 1],
       [0, 1, 1, 0]])

np.dot(cvals,dataset)
array([[1, 2, 2, 6, 0],
       [4, 7, 2, 3, 3]])

让我们创建两个定义:

def loop(cvals,dataset):
     cluster_count = np.max(cvals)+1
     doc_freq = np.zeros((cluster_count, dataset.shape[1]), dtype=dataset.dtype)
     for j in xrange(cluster_count):
         doc_freq[j] = np.sum(dataset[cvals == j], axis=0)
     return doc_freq

def matrix_mult(clusters,dataset):
     cvals = (clusters == np.arange(clusters.max()+1)[:,None]).astype(dataset.dtype)
     return np.dot(cvals,dataset)

现在有一些时间:

arr = np.random.random((2000,9500))
cluster = np.random.randint(0,5,(2000))

np.allclose(loop(cluster,arr),matrix_mult(cluster,arr))
True

%timeit loop(cluster,arr)
1 loops, best of 3: 263 ms per loop

%timeit matrix_mult(cluster,arr)
100 loops, best of 3: 14.1 ms per loop

请注意,这是使用线程mkl BLAS。你的家乡会有所不同。