根据其元素拆分Numpy数组,其中数组的每个元素都是唯一的

时间:2018-10-20 14:56:16

标签: python numpy vector split

我有一个Numpy一维向量,例如x = [1, 1, 1, 2, 2, 1, 3, 3, 1]

我必须将其拆分为n个子数组,其中每个向量必须以一个新值开头,并且只要该值相同就必须继续,这样最终的答案就是[[1, 1, 1], [2, 2], [1], [3, 3], [1]]

我确实知道必须使用numpy.split()函数,但是在查找必须拆分的位置时遇到了问题。

我谦虚地请您帮忙,谢谢您的光临!

1 个答案:

答案 0 :(得分:1)

您只需要为numpy.split提供索引即可拆分数组

a = np.array([1,1,1,2,2,1,3,3,1])
np.split(a, np.argwhere(np.diff(a) != 0)[:,0] + 1)
# [array([1, 1, 1]), array([2, 2]), array([1]), array([3, 3]), array([1])]

详细信息

使用np.diff(a),您可以得出每个连续元素之间的差异

np.diff(a)
# array([ 0,  0,  1,  0, -1,  2,  0, -2])

差异不等于0的点是元素不连续相同的点。由于您正在寻找需要更改的索引,因此np.diff(a) != 0会返回:

np.diff(a) != 0
# array([False, False,  True, False,  True,  True, False,  True])

要将布尔值转换为索引,可以使用np.argwhere

np.argwhere(np.diff(a) != 0)
# array([[2],[4],[5],[7]])
# since we only need this for 1d arrays
np.argwhere(np.diff(a) != 0)[:,0]
# array([2, 4, 5, 7])

您只需使用上述过程即可为np.split提供正确的索引

np.split(a, np.argwhere(np.diff(a) != 0)[:,0])
# [array([1, 1]), array([1, 2]), array([2]), array([1, 3]), array([3, 1])]

糟糕...索引错误...似乎我们相差1个索引。没问题,只需将+1添加到np.argwhere

的结果中
np.split(a, np.argwhere(np.diff(a) != 0)[:,0] + 1)
# [array([1, 1, 1]), array([2, 2]), array([1]), array([3, 3]), array([1])]