python列之间的所有可能产品

时间:2014-02-26 12:35:03

标签: python numpy itertools

我有一个numpy矩阵X,我想在这个矩阵中添加2个列之间所有可能产品的新变量。

So if X=(x1,x2,x3) I want X=(x1,x2,x3,x1x2,x2x3,x1x3)

有优雅的方式吗? 我认为numpy和itertools的组合应该可以工作

编辑: 非常好的答案,但他们是否认为X是一个矩阵?那么x1,x1,... x3最终可能是数组吗?

编辑: 一个真实的例子

a=array([[1,2,3],[4,5,6]])

3 个答案:

答案 0 :(得分:3)

Itertools应该是答案。

a = [1, 2, 3]
p = (x * y for x, y in itertools.combinations(a, 2))
print list(itertools.chain(a, p))

结果:

[1, 2, 3, 2, 3, 6] # 1, 2, 3, 2 x 1, 3 x 1, 3 x 2

答案 1 :(得分:3)

我认为Samy的解决方案非常好。如果您需要使用numpy,您可以将其转换为这样:

from itertools import combinations
from numpy import prod

x = [1, 2, 3]
print x + map(prod, combinations(x, 2))

提供与Samy解决方案相同的输出:

[1, 2, 3, 2, 3, 6]

答案 2 :(得分:2)

如果你的数组很小,那么使用Samy's pure-Python solutionitertools.combinations就可以了:

from itertools import combinations, chain

def all_products1(a):
    p = (x * y for x, y in combinations(a, 2))
    return list(chain(a, p))

但是如果您的阵列很大,那么使用numpy.triu_indices完全向量化计算将获得大幅加速,如下所示:

import numpy as np

def all_products2(a):
    x, y = np.triu_indices(len(a), 1)
    return np.r_[a, a[x] * a[y]]

让我们比较一下:

>>> data = np.random.uniform(0, 100, (10000,))
>>> timeit(lambda:all_products1(data), number=1)
53.745754408999346
>>> timeit(lambda:all_products2(data), number=1)
12.26144006299728

使用numpy.triu_indices的解决方案也适用于多维数据:

>>> np.random.uniform(0, 100, (3,2))
array([[ 63.75071196,  15.19461254],
       [ 94.33972762,  50.76916376],
       [ 88.24056878,  90.36136808]])
>>> all_products2(_)
array([[   63.75071196,    15.19461254],
       [   94.33972762,    50.76916376],
       [   88.24056878,    90.36136808],
       [ 6014.22480172,   771.41777239],
       [ 5625.39908354,  1373.00597677],
       [ 8324.59122432,  4587.57109368]])

如果要对列而不是行进行操作,请使用:

def all_products3(a):
    x, y = np.triu_indices(a.shape[1], 1)
    return np.c_[a, a[:,x] * a[:,y]]

例如:

>>> np.random.uniform(0, 100, (2,3))
array([[ 33.0062385 ,  28.17575024,  20.42504351],
       [ 40.84235995,  61.12417428,  58.74835028]])
>>> all_products3(_)
array([[   33.0062385 ,    28.17575024,    20.42504351,   929.97553238,
          674.15385734,   575.4909246 ],
       [   40.84235995,    61.12417428,    58.74835028,  2496.45552756,
         2399.42126888,  3590.94440122]])