numpy.where makes my code slow

Date: 2017-06-13 17:02:42

Tags: python performance numpy

I have the following block of code:

import numpy as np

def hasCleavage(tags, pair, fragsize):
    # Search window: mean fragment size plus four standard deviations.
    limit = int(fragsize["mean"] + fragsize["sd"] * 4)
    if pair.direction == "F1R2" or pair.direction == "R2F1":
        x1 = np.where((tags[pair.chr_r1] >= pair.r1["pos"]) & (tags[pair.chr_r1] <= pair.r1["pos"]+limit))[0]
        x2 = np.where((tags[pair.chr_r2] <= pair.r2["pos"]+pair.frside) & (tags[pair.chr_r2] >= pair.r2["pos"]+pair.frside-limit))[0]
    elif pair.direction == "F1F2" or pair.direction == "F2F1":
        x1 = np.where((tags[pair.chr_r1] >= pair.r1["pos"]) & (tags[pair.chr_r1] <= pair.r1["pos"]+limit))[0]
        x2 = np.where((tags[pair.chr_r2] >= pair.r2["pos"]) & (tags[pair.chr_r2] <= pair.r2["pos"]+limit))[0]
    elif pair.direction == "R1R2" or pair.direction == "R2R1":
        x1 = np.where((tags[pair.chr_r1] <= pair.r1["pos"]+pair.frside) & (tags[pair.chr_r1] >= pair.r1["pos"]+pair.frside-limit))[0]
        x2 = np.where((tags[pair.chr_r2] <= pair.r2["pos"]+pair.frside) & (tags[pair.chr_r2] >= pair.r2["pos"]+pair.frside-limit))[0]
    else:  # F2R1 or R1F2
        x1 = np.where((tags[pair.chr_r2] >= pair.r2["pos"]) & (tags[pair.chr_r2] <= pair.r2["pos"]+limit))[0]
        x2 = np.where((tags[pair.chr_r1] <= pair.r1["pos"]+pair.frside) & (tags[pair.chr_r1] >= pair.r1["pos"]+pair.frside-limit))[0]
    # Report a cleavage only if both windows contain at least one tag.
    return x1.size > 0 and x2.size > 0

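My script takes 16 minutes to finish. It calls hasCleavage millions of times, once per line read from a file. When I add `return True` just above the `limit` assignment (so that np.where is never called), the script takes 5 minutes.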
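To see how much a single range query costs on its own, the np.where expression can be timed in isolation. This is only a rough sketch and not the original script: the helper name `time_where`, the synthetic array, and its sizes are assumptions for illustration. The window width of 768 is what `int(mean + sd * 4)` gives for the fragsize values shown in the edit below.

```python
import timeit

import numpy as np


def time_where(n_tags, repeats=200):
    """Return the average seconds per np.where range query on n_tags sorted values."""
    # Synthetic stand-in for one entry of `tags`: sorted, ascending positions.
    rng = np.random.default_rng(0)
    arr = np.sort(rng.integers(0, 10 * n_tags, size=n_tags))
    # Hypothetical query window in the middle of the value range.
    lo = 5 * n_tags
    hi = lo + 768

    def query():
        # Same shape of expression as in hasCleavage: two masks, one np.where.
        return np.where((arr >= lo) & (arr <= hi))[0]

    return timeit.timeit(query, number=repeats) / repeats


if __name__ == "__main__":
    for n in (1_000, 100_000):
        print(f"{n:>7} tags: {time_where(n) * 1e6:.1f} us per call")
```

Each call builds two boolean masks over the whole array, so the per-call cost grows with the number of tags on the chromosome, however narrow the window is.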

tags is a dictionary whose values are NumPy arrays of numbers sorted in ascending order.
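Because each array is sorted, a binary search can answer "is there any tag in [lo, hi]?" without scanning the whole array. The sketch below uses np.searchsorted for this. The helper name `has_tag_in_window` is hypothetical and not part of the original code, and the approach is only valid if the arrays really are in ascending order.

```python
import numpy as np


def has_tag_in_window(arr, lo, hi):
    """Return True if the sorted array `arr` has at least one value in [lo, hi]."""
    # Index of the first element >= lo, found by binary search.
    i = np.searchsorted(arr, lo, side="left")
    # If such an element exists and is also <= hi, the window is non-empty.
    return bool(i < arr.size and arr[i] <= hi)


if __name__ == "__main__":
    arr = np.array([351, 1408, 2185, 2378, 2740])
    print(has_tag_in_window(arr, 300, 400))   # 351 lies in [300, 400]
    print(has_tag_in_window(arr, 400, 1400))  # nothing between 351 and 1408
```

In hasCleavage, each `np.where(...)[0]` followed by a `.size > 0` check could then be replaced by one such call, for example `has_tag_in_window(tags[pair.chr_r1], pair.r1["pos"], pair.r1["pos"] + limit)` for the first window of the F1R2 case.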

Do you have any suggestions for improving the performance?

Edit:

tags = {'JH584302.1': array([   351,   1408,   2185,   2378,   2740,   2904,   3364,   3657,
         4240,   5324,   5966,   5977,   5986,   6488,   6531,   6847,
         6961,   6973,   6991,   7107,   7383,   7395,   7557,   7569,
         9178,  10077,  10456,  10471,  11271,  11466,  12311,  12441,
        12598,  13051,  13123,  13859,  14167,  14672,  15156,  15252,
        15268,  15273,  15694,  15786,  16361,  17073,  17293,  17454])
}
fragsize = {'sd': 130.29407997430428, 'mean': 247.56636}

pair is an object of a custom class: <__main__.Pair object at 0x17129ad0>

0 answers:

No answers yet.