我有两个numpy数组,其中一个包含大约1%的NaN。
a = np.array([-2,5,nan,6])
b = np.array([2,3,1,0])
我想使用a
的{{3}}计算b
和sklearn
的均方误差。
所以我的问题是,从a
删除所有NaN的python方法是什么,同时尽可能有效地从b
删除所有相应的条目?
答案 0 :(得分:2)
您可以简单地使用香草NumPy的np.nanmean
:
In [136]: np.nanmean((a-b)**2)
Out[136]: 18.666666666666668
如果这个不存在,或者您真的想使用sklearn
方法,则可以创建一个mask来索引NaN:
In [148]: mask = ~np.isnan(a)
In [149]: mean_squared_error(a[mask], b[mask])
Out[149]: 18.666666666666668