我有一个Numpy一维向量,例如x = [1, 1, 1, 2, 2, 1, 3, 3, 1]
我必须将其拆分为n个子数组,其中每个向量必须以一个新值开头,并且只要该值相同就必须继续,这样最终的答案就是[[1, 1, 1], [2, 2], [1], [3, 3], [1]]
。
我确实知道必须使用numpy.split()
函数,但是在查找必须拆分的位置时遇到了问题。
我谦虚地请您帮忙,谢谢您的光临!
答案 0 :(得分:1)
您只需要为numpy.split
提供索引即可拆分数组
a = np.array([1,1,1,2,2,1,3,3,1])
np.split(a, np.argwhere(np.diff(a) != 0)[:,0] + 1)
# [array([1, 1, 1]), array([2, 2]), array([1]), array([3, 3]), array([1])]
详细信息
使用np.diff(a)
,您可以得出每个连续元素之间的差异
np.diff(a)
# array([ 0, 0, 1, 0, -1, 2, 0, -2])
差异不等于0的点是元素不连续相同的点。由于您正在寻找需要更改的索引,因此np.diff(a) != 0
会返回:
np.diff(a) != 0
# array([False, False, True, False, True, True, False, True])
要将布尔值转换为索引,可以使用np.argwhere
np.argwhere(np.diff(a) != 0)
# array([[2],[4],[5],[7]])
# since we only need this for 1d arrays
np.argwhere(np.diff(a) != 0)[:,0]
# array([2, 4, 5, 7])
您只需使用上述过程即可为np.split
提供正确的索引
np.split(a, np.argwhere(np.diff(a) != 0)[:,0])
# [array([1, 1]), array([1, 2]), array([2]), array([1, 3]), array([3, 1])]
糟糕...索引错误...似乎我们相差1个索引。没问题,只需将+1添加到np.argwhere
np.split(a, np.argwhere(np.diff(a) != 0)[:,0] + 1)
# [array([1, 1, 1]), array([2, 2]), array([1]), array([3, 3]), array([1])]