如何过滤掉包含列表列中特定子序列的 Pandas DataFrame 中的行?

时间:2021-05-26 12:47:56

标签: python pandas dataframe performance

我有一个如下所示的 DataFrame:

df = pd.DataFrame({"id": [1, 2, 3, 4, 5],
                   "list": [[2, 51, 6, 8, 3], [19, 2, 11, 9], [6, 8, 3, 9, 10, 11], [4, 5], [8, 3, 9, 6]]})

我想过滤这个 DataFrame 使其只包含 X 是子序列的行(所以 X 中元素的顺序与列表中的元素顺序相同,并且它们不与list 列的列表中的其他元素

例如,如果 X = [6, 8, 3],我希望输出如下所示:

id    list
1     [2, 51, 6, 8, 3]
3     [6, 8, 3, 9, 10, 11]

我知道我可以使用以下函数(在 How to check subsequence exists in a list? 上找到)检查一个列表是否是另一个列表的子序列:

def x_in_y(query, base):
    l = len(query)

    for i in range(len(base) - l + 1):
        if base[i:i+l] == query:
            return True
    return False

我有两个问题:

问题 1:

如何将其应用到我的示例中的 Pandas DataFrame 列?

问题 2:

这是最有效的方法吗?如果不是,那会是什么?该函数看起来不是那么优雅/Pythonic,我必须将它应用于大约 200K 行的非常大的 DataFrame。

[注意:list 列中的列表元素是唯一的,应该有助于优化事物]

4 个答案:

答案 0 :(得分:2)

这是列的解决方案调用函数:

df = df[df.list.map(lambda x: x_in_y(X, x))]
#alternative
#df = df[df.list.apply(lambda x: x_in_y(X, x))]
print (df)
   id                  list
0   1      [2, 51, 6, 8, 3]
2   3  [6, 8, 3, 9, 10, 11]

在样本数据中表现非常好,实际上也是最好的测试:

#200k rows
df = pd.concat([df] * 40000, ignore_index=True)
print (df)

X = [6, 8, 3]
x = to_string([6, 8, 3])


In [166]: %timeit df.list.map(lambda x: x_in_y(X, x))
214 ms ± 6.52 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [167]: %timeit df['list'].map(to_string).str.contains(x)
413 ms ± 4.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [168]: %timeit df["list"].apply(has_subsequence, subseq=X)
5.2 s ± 420 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [169]: %timeit df.list.apply(lambda y: ''.join(map(str,X)) in ''.join(map(str,y)))
573 ms ± 116 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

答案 1 :(得分:2)

使用 numpy rolling-window technique

import numpy as np

def rolling_window(a, size):
    a = np.array(a)
    shape = a.shape[:-1] + (a.shape[-1] - size + 1, size)
    strides = a.strides + (a.strides[-1],)
    return np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides)

def has_subsequence(a, subseq):
    return (rolling_window(a, len(subseq)) == subseq).all(axis=1).any()

mask = df["list"].apply(has_subsequence, subseq=[6,8,3])
df[mask]

说明:

rolling_window 使用给定的形状和步幅创建数组视图:

>>> rolling_window([1,2,3,4], 2)
np.array([[1,2], [2,3], [3,4]])

然后我们将结果与我们的目标 X

进行比较
>>> np.array([[1,2], [2,3], [3,4]]) == [2,3]
np.array([[False, False], [True, True], [False, False]])

然后我们告诉 numpy 返回 True,以防所有项目在第 1 轴上都是 True

>>> np.array([[False, False], [True, True], [False, False]]).all(axis=1)
np.array([False, True, False])

如果数组中有True,最后返回True。

>>> np.array([False, True, False]).any()

答案 2 :(得分:2)

你可以试试这个:

import pandas as pd

def to_string(l):
    return '-' + '-'.join(map(str, l)) + '-'

X = to_string([6, 8, 3])
df = pd.DataFrame({"id": [1, 2, 3, 4, 5], "list": [[2, 51, 6, 8, 3], [19, 2, 11, 9], [6, 8, 3, 9, 10, 11], [4, 5], [8, 3, 9, 6]]})

df[df['list'].map(to_string).str.contains(X)]

#    id                  list
# 0   1      [2, 51, 6, 8, 3]
# 2   3  [6, 8, 3, 9, 10, 11]

在我看来,在字符串的开头和结尾添加分隔符很重要。否则,您将遇到列表问题,例如:[666, 8, 3]

答案 3 :(得分:1)

你可以试试这个:

x = [6, 8, 3]
df = df.loc[df.list.apply(lambda y: ''.join(map(str,x)) in ''.join(map(str,y)))]

输出:

   id                  list
0   1      [2, 51, 6, 8, 3]
2   3  [6, 8, 3, 9, 10, 11]