数组的最长前缀等于值

时间:2013-07-29 16:41:02

标签: python numpy

我有NumPy字符串数组,表示序列的分段。 B是感兴趣细分受众群的开始,I其延续,O在任何细分受众群之外。例如,在以下数组中,有三个感兴趣的部分:

>>> y
array(['B', 'I', 'I', 'O', 'B', 'I', 'O', 'O', 'B', 'O'], 
      dtype='|S1')

我可以使用np.where(y == "B")[0]轻松找到细分受众群。但是现在我试图找到段的长度,即最长前缀的长度等于I。我可以使用itertools.takewhile

这样做
>>> from itertools import takewhile
>>> lengths = [1 + sum(1 for _ in takewhile(lambda x: x == "I", y[start + 1:]))
...            for start in np.where(y == "B")[0]]
>>> lengths
[3, 2, 1]

,老实说,工作正常,但有没有一种矢量化的方法来实现这个目标?

2 个答案:

答案 0 :(得分:1)

搜索排序可以在这里提供帮助:

>>> y
array(['B', 'I', 'I', 'O', 'B', 'I', 'O', 'O', 'B', 'O'],
      dtype='|S1')
>>> start=np.where(y=='B')[0]
>>> end=np.where(y=='O')[0]

>>> end[np.searchsorted(end,start)]-start
array([3, 2, 1])

另一种方法:

>>> mask=np.concatenate(([True],(np.diff(end)!=1)))
>>> mask
array([ True,  True, False,  True], dtype=bool)
>>> end[mask]-start
array([3, 2, 1])

答案 1 :(得分:1)

细分以'B'

开头
starts = np.where(y == 'B')[0]

段落在'B''I'后跟'I'以外的其他内容或序列末尾的位置结束:

ends = np.where(((y == 'B') | (y == 'I')) & np.r_[y[1:] != 'I', len(y)])[0]

这给出了段长度:

(ends - starts) + 1
array([3, 2, 1])

编辑:这是一个更简单的方法:在末尾插入一个虚构的B,然后取消(真实的或虚构的)B的位置差异,排除O s :

np.diff(np.where(np.r_[y[y != 'O'], ['B']] == 'B')[0])
array([3, 2, 1])