如何过滤掉含有NaN的子阵列?

时间:2013-11-08 13:29:54

标签: python numpy

让我们假设一个形状(n,5,2)的数组,它在随机位置包含NaN,由以下代码生成:

n = 10
arr = np.random.rand(n, 5, 2)

# replace some values by nan
arr = arr.ravel()
index_array = np.arange(arr.size)
np.random.shuffle(index_array)
arr[index_array[:5]] = np.nan
arr = arr.reshape(n, 5, 2)

如何有效地过滤此数组,以便仅保留那些不包含arr[i] s的NaN?然后,生成的形状为(m,5,2) m<=n

2 个答案:

答案 0 :(得分:4)

无需重塑任何内容:

has_nans = np.isnan(arr).any(axis=(-1,-2))
has_nans 
array([False, False, False,  True,  True,  True, False, False, False,  True], dtype=bool)

>>> arr = arr[~has_nans]
>>> arr.shape
(6, 5, 2)

numpy的旧版本,您需要执行以下操作:

has_nans = np.isnan(arr).any(axis=-1).any(axis=-1)

答案 1 :(得分:0)

这是1班轮:

new = arr[~np.isnan(arr).any((-1,-2))]

print new.shape
Out[10]: (5, 5, 2)