我正在尝试从大型numpy数组中提取行。数组的列是obs数,group id(j),time id(t)和一些数据x_jt。
以下是一个例子:
import numpy as np
N = 100
T = 100
X = np.vstack((np.array(range(1,N*T+1)),np.repeat(np.array(range(1,N+1)),T), np.tile(np.array(range(1,T+1)),N), np.random.randint(100,size=N*T))).T
如果我想从组ID为2的X中提取所有行,我会做
X[np.where(X[:,1] == 2)]
如果我想要j = 2或3的所有行,我可以扩展该代码。但是,在我的情况下,我有很多组ID(j' s)要提取。具体来说,我想提取j来自
的所有行samples = np.random.randint(N, size=N) + 1
例如,假设size = 5而不是N,samples =(2,4,5,4,7)。我所追求的是通过X并选择j = 2,然后j = 4,然后j = 5,j = 4,最后j = 7的所有行的代码,并创建一个包含结果的新数组。基本上这个:
result = []
for j in samples:
result.extend(X[np.where(X[:,1] == j)])
但是,当N很大时,此代码很慢。你有任何建议加快它吗?谢谢!
答案 0 :(得分:1)
这可以使用矢量化函数来完成:
def contains(X, samples):
return numpy.vectorize(lambda x: x in samples)(X)
result = X[contains(X[:, 1], set(samples)), :]
如果您想要更换,只需检查每个样品一个计数,直到没有更多样品(假设订单无关紧要)。这样,您至少可以减少迭代矩阵所需的次数。
result = []
sample_counts = collections.Counter(samples)
while sum(sample_counts.itervalues()):
# pick up one of each of the remaining samples and chain their rows
# together in result
s = set(key for key, value in sample_counts.iteritems() if value)
result = itertools.chain(result, X[contains(X[:, 1], s), :])
sample_counts -= collections.Counter(dict.fromkeys(s, 1))
# create a matrix of the final result
result = numpy.array(list(result))
在这种情况下,我能想到的唯一方法可能会加快你已经在做的事情,那就是预先分配一个矩阵。所以你会这样做:
答案 1 :(得分:0)
它并没有完全按照你所描述的那样做,但这类问题是np.in1d
的一个很好的候选者。这样的事情应该有效:
result = X[np.in1d(X[:, 1], samples)]