numpy数组:如何按行检查前X个值是否有效?

时间:2019-02-19 19:28:02

标签: python numpy

问题描述

请考虑以下两个示例数组:

arr = np.array([
    [5.0, 2.0, 1.0, np.nan, np.nan],
    [9.0, np.nan, np.nan, np.nan, 2.0],
    [4.0, 7.0, 4.0, np.nan, np.nan],
    [8.0, np.nan, np.nan, np.nan, np.nan],
    [np.nan, np.nan, np.nan, np.nan, np.nan],
    [np.nan, np.nan, np.nan, np.nan, 6.0]
])

amounts = np.array([
    3,
    1,
    2,
    3,
    0,
    5
])

对于数组arr中的每一行,我想检查该行中的前X个条目是否不是NaN,但所有其他条目都是NaN。 X的每一行的数量都不相同,由数组amounts给出。

所以我的预期结果将是以下布尔数组:

array([ True, False, False, False,  True, False])

到目前为止已尝试

我设法提出了以下工作代码:

result = []
for (row, amount) in zip(arr, amounts):
    if (~np.isnan(row)[:amount]).all() and np.isnan(row)[amount:].all():
        result.append(True)
    else:
        result.append(False)

result = np.array(result)
print(result)

尽管此代码产生了预期的结果,但我仍然觉得它仍然效率低下。我怀疑没有任何for循环的方法是可能的,但是我还没有找到它。

任何人都可以帮助找到针对此问题的完全矢量化解决方案吗?

2 个答案:

答案 0 :(得分:4)

a = np.array([[5.0, 2.0, 1.0, np.nan, np.nan],
              [9.0, np.nan, np.nan, np.nan, 2.0],
              [4.0, 7.0, 4.0, np.nan, np.nan],
              [8.0, np.nan, np.nan, np.nan, np.nan],
              [np.nan, np.nan, np.nan, np.nan, np.nan],
              [np.nan, np.nan, np.nan, np.nan, 6.0]])

b = np.array([3,1,2,3,0,5])

c = np.logical_not(np.isnan(a))
firstn = b == c.argmin(axis=1)
no_extras = b == c.sum(axis=1)
result = np.logical_and(firstn,no_extras)

制作一个非NaN值的布尔数组。

要确保前n个值符合条件; 使用numpy.argmin()查找第一个NaN-将其与counts数组进行比较。

要确保在NaNs start 之后的 中没有任何非NaN值;将布尔数组中所有True的行进行求和,并将其与counts数组进行比较。

and这两个结果。

答案 1 :(得分:2)

您可以这样尝试:

# Values are column numbers
grid = np.tile(np.arange(arr.shape[1]), (arr.shape[0], 1))

# Mask
mask = grid < amounts.reshape((-1, 1))

# Comparison
np.all(~np.isnan(arr) == mask, axis=1)