优化嵌套的for循环

时间:2019-08-05 18:00:04

标签: python numpy

我有一个 pandas 数据框(DataFrame),其中包含一系列A,B,C,D列(0或1),以及一系列描述它们之间相互作用的AB,AC,AD,BC,BD,CD列(也为0或1)。

基于这些相互作用,我想像以下最小工作示例(MWE)中那样确定“三元组” ABC,ABD,ACD,BCD 是否存在:

import numpy as np
import pandas as pd
df = pd.DataFrame()

np.random.seed(1)

# Single columns first, then pair-interaction columns; the creation order
# fixes the sequence of RNG draws, so it must match the original exactly.
for name in ["A", "B", "C", "D", "AB", "AC", "AD", "BC", "BD", "CD"]:
    df[name] = np.random.randint(2, size=10)

ls = ["A", "B", "C", "D"]
n = len(ls)
for i in range(n):
    for j in range(i + 1, n):
        for k in range(j + 1, n):
            a, b, c = ls[i], ls[j], ls[k]
            triplet = a + b + c

            # Rows where all three singles are active.
            active = (df[a] > 0) & (df[b] > 0) & (df[c] > 0)
            # Sum of the three pair interactions, restricted to those rows.
            pair_sum = df.loc[active, a + b] + df.loc[active, b + c] + df.loc[active, a + c]

            # Triplet is flagged (999) where at least two pairs are set.
            df[triplet] = 0
            df.loc[pair_sum.index[pair_sum >= 2], triplet] = 999

这将提供以下输出:

   A  B  C  D  AB  AC  AD  BC  BD  CD  ABC  ABD  ACD  BCD
0  1  0  0  0   1   0   0   1   1   0    0    0    0    0
1  1  1  1  0   1   1   1   1   0   0  999    0    0    0
2  0  0  0  1   1   0   1   0   0   1    0    0    0    0
3  0  1  0  1   1   0   0   0   1   1    0    0    0    0
4  1  1  1  1   1   1   1   0   1   1  999  999  999  999
5  1  0  0  1   1   1   1   0   0   0    0    0    0    0
6  1  0  0  1   0   1   1   1   1   1    0    0    0    0
7  1  1  0  0   1   0   1   1   1   1    0    0    0    0
8  1  0  1  0   1   1   0   1   0   0    0    0    0    0
9  0  0  0  0   0   0   0   0   1   1    0    0    0    0

代码的逻辑如下:如果 A、B、C 均处于活动状态(=1),并且 AB、AC、BC 三列中至少有两列处于活动状态(=1),则三元组 ABC 处于活动状态。

我总是从查看各个单列开始(对于ABC,就是A,B和C)。查看A,B和C列,我们仅“保留” A,B和C均为非零的行。然后,查看AB,AC和BC的交互作用:只有当AB,AC和BC中至少有两个为1时,三元组ABC才被“激活”——这只在第1行和第4行成立!因此,对于第1行和第4行,ABC = 999;对于所有其他行,ABC = 0。我会为所有可能的三元组(在这种情况下为4个)执行此操作。

由于数据帧较小,因此上面的代码运行速度很快。但是,在我的真实代码中,数据框具有超过一百万行和数百次交互,在这种情况下,它运行非常慢。

有没有一种方法可以优化上述代码,例如通过多线程吗?

1 个答案:

答案 0 :(得分:2)

这是比参考代码快10倍的方法。它没有做任何特别巧妙的事情,只是一些常规的、按部就班的优化。

import numpy as np
import pandas as pd
df = pd.DataFrame()

np.random.seed(1)

# The four single columns, then all six pair-interaction columns.
# Creation order fixes the RNG draw sequence, so it must stay unchanged.
ls = ["A", "B", "C", "D"]
for col in ls + ["AB", "AC", "AD", "BC", "BD", "CD"]:
    df[col] = np.random.randint(2, size=10)

def op():
    """Reference (slow) implementation: for every 3-combination of the
    single columns in module-level `ls`, add a triplet column to a copy of
    module-level `df`, set to 99 where all three singles are active and at
    least two of the three pair-interaction columns are set, else 0.
    """
    out = df.copy()
    n = len(ls)
    for i in range(n):
        for j in range(i + 1, n):
            for k in range(j + 1, n):
                a, b, c = ls[i], ls[j], ls[k]
                col = a + b + c

                # Restrict to rows where the three singles are all active,
                # then sum the three pair interactions on those rows.
                active = (out[a] > 0) & (out[b] > 0) & (out[c] > 0)
                pair_sum = out.loc[active, a + b] + out.loc[active, b + c] + out.loc[active, a + c]

                out[col] = 0
                out.loc[pair_sum.index[pair_sum >= 2], col] = 99
    return out

import scipy.spatial.distance as ssd

def pp():
    """Vectorized equivalent of op(): computes every triplet column in one
    NumPy pass instead of looping over pandas selections.

    Relies on module-level `df` (single columns first, then pair columns in
    itertools.combinations order) and `ls` (the single-column names).
    Returns a copy of df with one extra column per 3-combination, valued 99
    where the triplet is active and 0 elsewhere.
    """
    data = df.values
    n = len(ls)
    # Split into the n single columns (d1) and the n*(n-1)//2 pair columns (d2).
    d1,d2 = np.split(data, [n], axis=1)
    # Endpoint indices of each pair column, in the order the pairs were built.
    i,j = np.triu_indices(n,1)
    # Zero a pair unless both of its member singles are active; this reproduces
    # op()'s restriction to rows where all three singles are 1 (any triplet with
    # an inactive single loses at least two of its pairs, so its sum stays < 2).
    d2 = d2 & d1[:,i] & d1[:,j]
    # Enumerate all index triples with k < i < j, i.e. the 3-combinations.
    k,i,j = np.ogrid[:n,:n,:n]
    k,i,j = np.where((k<i)&(i<j))
    # lu[a,b] is the condensed (pair-column) index of the pair (a, b), so d2
    # can be gathered by endpoint pairs.
    lu = ssd.squareform(np.arange(n*(n-1)//2))
    # A triplet is active when at least 2 of its 3 masked pairs are set;
    # the boolean result is reinterpreted as uint8 and scaled to 99.
    d3 = ((d2[:,lu[k,i]]+d2[:,lu[i,j]]+d2[:,lu[k,j]])>=2).view(np.uint8)*99
    *triplets, = map("".join, combinations(ls,3))
    out = df.copy()
    out[triplets] = pd.DataFrame(d3, columns=triplets)
    return out

from string import ascii_uppercase
from itertools import combinations, chain

def make(nl=8, nr=1000000, seed=1):
    """Build a random benchmark frame.

    Returns (letters, frame): `letters` is an array of the first `nl`
    uppercase letters, and `frame` has one uint8 0/1 column per letter
    followed by one per 2-combination of letters, each with `nr` rows.
    """
    np.random.seed(seed)
    letters = np.fromiter(ascii_uppercase, 'U1', nl)
    frame = pd.DataFrame()
    # Singles first, then pairs — same column order the solvers expect.
    pair_names = map("".join, combinations(letters, 2))
    for name in chain(letters, pair_names):
        frame[name] = np.random.randint(0, 2, nr, dtype=np.uint8)
    return letters, frame

# Correctness check on the original 10-row frame.  NOTE: statement order
# matters throughout this section — op() and pp() read the module-level
# `ls` and `df`, which are rebound by make() below.
df1 = op()
df2 = pp()
# Elementwise comparison rather than DataFrame.equals, since the two
# results may differ in dtype (pp builds the triplet columns as uint8).
assert (df1==df2).all().all()

# Re-check agreement on a larger random instance: 8 symbols, 1000 rows.
ls, df = make(8,1000)

df1 = op()
df2 = pp()
assert (df1==df2).all().all()

from timeit import timeit

# Benchmark both implementations on the 8-symbol / 1000-row frame.
print(timeit(op,number=10))
print(timeit(pp,number=10))

# Scale test: 26 symbols (325 pair columns), 250k rows — only pp() is timed
# here, with a single wall-clock measurement.
ls, df = make(26,250000)
import time

t0 = time.perf_counter()
df2 = pp()
t1 = time.perf_counter()
print(t1-t0)

样品运行:

3.2022583668585867 # op 8 symbols, 1000 rows, 10 repeats
0.2772211490664631 # pp 8 symbols, 1000 rows, 10 repeats
12.412292044842616 # pp 26 symbols, 250,000 rows, single run