如何在循环中索引numpy数组时提高速度?

时间:2017-11-01 12:45:47

标签: python numpy

h5id_unique = np.unique(df_[:,0])
cnt_nan = 0
cnt_pos = 0
cnt_neg = 0
cnt = 0
for h5id in h5id_unique:
    dfq_ = df_[df_[:,0]==h5id]
    if dfq_.shape[0] <=2 or dfq_[:,1].sum() != 1:
        cnt_nan += 1
        continue
    cnt += 1
    pos_score = dfq_[dfq_[:,1]==1, 2]
    neg_score = dfq_[dfq_[:,1]==0, 2]
    for i in neg_score:
        if i <= pos_score:
            cnt_pos += 1
        else:
            cnt_neg +=1
    if cnt % 500 == 0:
        print cnt_pos / float(cnt_neg), cnt_nan, cnt

我有一个名为df_的numpy数组,它有三列

h5id, label, pred

h5id是字符串格式的id,而label是0/1 int,pred是浮点数。这是我的代码。 df_有11百万行,而阵列中有300万个不同的h5id。 我发现我的代码很慢。我怎样才能改进它?我认为索引操作需要花费太多时间。它将指数300万次。 感谢。

1 个答案:

答案 0 :(得分:0)

一些建议,但没有完全理解您的代码或运行的数据。 您可以尝试根据第一列对df_进行排序,然后您可以根据相同的键彼此相邻的事实进行更智能的索引。

根据数据,您使用的以下代码提示了一种更简单的方法

if dfq_.shape[0] <=2 or dfq_[:,1].sum() != 1:
    cnt_nan += 1
    continue

这将过滤掉标识符出现两次或更少的任何行,以及第一列和第二列之和不等于1的任何行。如果大多数数据不符合这些条件,那么它将快得多根据这些条件过滤数据,然后处理剩余部分。以下伪代码给出了过滤的粗略方法

vals, idx, count = unique(0.5+randint(2, size=10), return_counts=True, return_inverse=True)
cond1 = (count[idx] > 2)
cond2 = (abs(sum(df_[:, 1:], axis=1) - 1) < 1e-9)

newdf = df_[cond1 & cond2]

然后以与newdf

相同的方式处理df_