将元素插入numpy数组,以便最小间距是任意的

时间:2018-10-11 21:34:59

标签: python numpy

给定一个有序的numpy浮点数组(最小到最大),我需要确保元素之间的间距小于我称为step的任意浮点。

这是我的代码,据我所知它可以正常工作,但是我想知道是否有更优雅的方法可以做到这一点:

import numpy as np

def slpitArr(arr, step=3.):
    """
    Insert extra elements into array so that the maximum spacing between
    elements is 'step'.
    """
    # Keep going until no more elements need to be added
    while True:
        flagExit = True
        for i, v in enumerate(arr):
            # Catch last element in list
            try:
                if abs(arr[i + 1] - v) > step:
                    new_v = (arr[i + 1] + v) / 2.
                    flagExit = False
                    break
            except IndexError:
                pass
        if flagExit:
            break
        # Insert new element
        arr = np.insert(arr, i + 1, new_v)

    return arr


aa = np.array([10.08, 14.23, 19.47, 21.855, 24.34, 25.02])

print(aa)
print(slpitArr(aa))

结果为:

[10.08  14.23  19.47  21.855 24.34  25.02 ]
[10.08  12.155 14.23  16.85  19.47  21.855 24.34  25.02 ]

2 个答案:

答案 0 :(得分:4)

这是一种单通解决方案,

1)计算连续点之间的差异 d

2)ceil逐步将 d 划分为 m

2a)可选地将 m 舍入为最接近的2的幂

3)将 d 除以 m ,并重复结果 m

4)形成累加和

这是代码。技术说明:d的第一个元素不是差异,而是“锚点”,因此它等于数据的第一个元素。

def fill(data, step, force_power_of_two=True):
    d = data.copy()
    d[1:] -= data[:-1]
    if force_power_of_two:
        m = 1 << (np.frexp(np.nextafter(d / step, -1))[1]).clip(0, None)
    else:
        m = -(d // -step).astype(int)
    m[0] = 1
    d /= m
    return np.cumsum(d.repeat(m))

样品运行:

>>> inp
array([10.08 , 14.23 , 19.47 , 21.855, 24.34 , 25.02 ])
>>> fill(inp, 3)
array([10.08 , 12.155, 14.23 , 16.85 , 19.47 , 21.855, 24.34 , 25.02 ])

答案 1 :(得分:1)

对于有序数组:

def slpitArr(arr, step=3.):
    d = np.ediff1d(arr)
    n = (d / step).astype(dtype=np.int)
    idx = np.flatnonzero(n)
    indices = np.repeat(idx, n[idx]) + 1
    values = np.concatenate(
        [np.linspace(s1, s2, i+1, False)[1:] for s1, s2, i in zip(arr[:-1], arr[1:], n)])
    return np.insert(arr, indices, values)

然后

>>> aa = np.array([10.08, 14.23, 19.47, 21.855, 24.34, 25.02])
>>> print(slpitArr(aa))
[10.08  12.155 14.23  16.85  19.47  21.855 24.34  25.02 ]

>>> print(slpitArr(aa, 2.5))
[10.08       12.155      14.23       15.97666667 17.72333333 19.47
 21.855      24.34       25.02      ]