xarray.apply_ufunc() with GroupBy: unexpected number of dimensions

Date: 2018-11-01 20:05:52

Tags: python-xarray

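I'm using xarray.apply_ufunc() to apply a function to an xarray.DataArray. It works fine on some NetCDF files but fails on others that look comparable in terms of dimensions, coordinates, etc. Something must differ between the files the code works on and the files it fails on, and I'm hoping someone can spot the problem after looking at the code and the file metadata listed below.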

The code I'm running is this:

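import xarray as xr

from climate_indices import indices  # assumed: the package providing indices.spi

# kwrgs: dict of script arguments, defined elsewhere in the original script
# open the precipitation NetCDF as an xarray Dataset object
dataset = xr.open_dataset(kwrgs['netcdf_precip'])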

# get the precipitation array, over which we'll compute the SPI
da_precip = dataset[kwrgs['var_name_precip']]

# stack the lat and lon dimensions into a new dimension named point, so at each lat/lon
# we'll have a time series for the geospatial point, and group by these points
da_precip_groupby = da_precip.stack(point=('lat', 'lon')).groupby('point')

# apply the SPI function to the data array
da_spi = xr.apply_ufunc(indices.spi,
                        da_precip_groupby)

# unstack the array back into original dimensions
da_spi = da_spi.unstack('point')
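The stack/unstack round trip in the code above can be checked on a small synthetic array. This is a minimal sketch assuming only xarray and NumPy; the grid, values and variable names are hypothetical. It shows that stacking `lat` and `lon` leaves dimensions `('time', 'point')`, the same pair named in the error message further down, with one time series per grid point.

```python
import numpy as np
import xarray as xr

# Hypothetical 2 x 3 grid with 5 time steps, coordinates in ascending order.
da = xr.DataArray(
    np.arange(30, dtype=float).reshape(2, 3, 5),
    coords={"lat": [24.5, 25.0], "lon": [-124.0, -123.5, -123.0], "time": np.arange(5)},
    dims=("lat", "lon", "time"),
)

# stack() appends the new 'point' dimension after the remaining ones,
# so the result has dims ('time', 'point'): one column per lat/lon pair.
stacked = da.stack(point=("lat", "lon"))

# The series at stacked position 0 is the pixel at lat[0], lon[0].
first_series = stacked.isel(point=0)

# unstack() restores lat/lon; transpose() restores the original dimension order.
restored = stacked.unstack("point").transpose("lat", "lon", "time")

print(stacked.dims, stacked.sizes["point"])
```

Note that `unstack` puts `time` first, so the `transpose` call is only needed if the original dimension order matters downstream.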

The NetCDF that works looks like this:

>>> import xarray as xr
>>> ds_good = xr.open_dataset("good.nc")
>>> ds_good
<xarray.Dataset>
Dimensions:  (lat: 38, lon: 87, time: 1466)
Coordinates:
  * lat      (lat) float32 24.5625 25.229166 25.895834 ... 48.5625 49.229168
  * lon      (lon) float32 -124.6875 -124.020836 ... -68.020836 -67.354164
  * time     (time) datetime64[ns] 1895-01-01 1895-02-01 ... 2017-02-01
Data variables:
    prcp     (lat, lon, time) float32 ...
Attributes:
    Conventions:               CF-1.6, ACDD-1.3
    ncei_template_version:     NCEI_NetCDF_Grid_Template_v2.0
    title:                     nClimGrid
    naming_authority:          gov.noaa.ncei
    standard_name_vocabulary:  Standard Name Table v35
    institution:               National Centers for Environmental Information...
    geospatial_lat_min:        24.5625
    geospatial_lat_max:        49.354168
    geospatial_lon_min:        -124.6875
    geospatial_lon_max:        -67.020836
    geospatial_lat_units:      degrees_north
    geospatial_lon_units:      degrees_east
    NCO:                       4.7.1
    nco_openmp_thread_number:  1
>>> ds_good.prcp
<xarray.DataArray 'prcp' (lat: 38, lon: 87, time: 1466)>
[4846596 values with dtype=float32]
Coordinates:
  * lat      (lat) float32 24.5625 25.229166 25.895834 ... 48.5625 49.229168
  * lon      (lon) float32 -124.6875 -124.020836 ... -68.020836 -67.354164
  * time     (time) datetime64[ns] 1895-01-01 1895-02-01 ... 2017-02-01
Attributes:
    valid_min:      0.0
    units:          millimeter
    valid_max:      2000.0
    standard_name:  precipitation_amount
    long_name:      Precipitation, monthly total

The NetCDF that fails looks like this:

>>> ds_bad = xr.open_dataset("bad.nc")
>>> ds_bad
<xarray.Dataset>
Dimensions:  (lat: 38, lon: 87, time: 1483)
Coordinates:
  * lat      (lat) float32 49.3542 48.687534 48.020866 ... 25.3542 24.687532
  * lon      (lon) float32 -124.6875 -124.020836 ... -68.020836 -67.354164
  * time     (time) datetime64[ns] 1895-01-01 1895-02-01 ... 2018-07-01
Data variables:
    prcp     (lat, lon, time) float32 ...
Attributes:
    date_created:              2018-02-15 10:29:25.485927
    date_modified:             2018-02-15 10:29:25.486042
    Conventions:               CF-1.6, ACDD-1.3
    ncei_template_version:     NCEI_NetCDF_Grid_Template_v2.0
    title:                     nClimGrid
    naming_authority:          gov.noaa.ncei
    standard_name_vocabulary:  Standard Name Table v35
    institution:               National Centers for Environmental Information...
    geospatial_lat_min:        24.562532
    geospatial_lat_max:        49.3542
    geospatial_lon_min:        -124.6875
    geospatial_lon_max:        -67.020836
    geospatial_lat_units:      degrees_north
    geospatial_lon_units:      degrees_east
>>> ds_bad.prcp
<xarray.DataArray 'prcp' (lat: 38, lon: 87, time: 1483)>
[4902798 values with dtype=float32]
Coordinates:
  * lat      (lat) float32 49.3542 48.687534 48.020866 ... 25.3542 24.687532
  * lon      (lon) float32 -124.6875 -124.020836 ... -68.020836 -67.354164
  * time     (time) datetime64[ns] 1895-01-01 1895-02-01 ... 2018-07-01
Attributes:
    valid_min:      0.0
    long_name:      Precipitation, monthly total
    standard_name:  precipitation_amount
    units:          millimeter
    valid_max:      2000.0

When I run the code on the first file above, it works without errors. With the second file, I get an error like this:

multiprocessing.pool.RemoteTraceback:
"""
Traceback (most recent call last):
  File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
    return list(map(*args))
  File "/home/paperspace/git/climate_indices/scripts/process_grid_ufunc.py", line 278, in compute_write_spi
    kwargs=args_dict)
  File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 974, in apply_ufunc
    return apply_groupby_ufunc(this_apply, *args)
  File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 432, in apply_groupby_ufunc
    applied_example, applied = peek_at(applied)
  File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/utils.py", line 133, in peek_at
    peek = next(gen)
  File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 431, in <genexpr>
    applied = (func(*zipped_args) for zipped_args in zip(*iterators))
  File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 987, in apply_ufunc
    exclude_dims=exclude_dims)
  File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 211, in apply_dataarray_ufunc
    result_var = func(*data_vars)
  File "/home/paperspace/anaconda3/envs/climate/lib/python3.6/site-packages/xarray/core/computation.py", line 579, in apply_variable_ufunc
    .format(data.ndim, len(dims), dims))
ValueError: applied function returned data with unexpected number of dimensions: 1 vs 2, for dimensions ('time', 'point')

Can anyone comment on what the problem might be?

2 Answers

Answer 0 (score: 2)

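It turns out that the problematic NetCDF file has its latitude coordinate values in descending order. xarray.apply_ufunc() seems to require coordinate values in ascending order, at least to avoid this particular problem. This is easily fixed before the file is used as input to xarray: NCO's ncpdq command can reverse the coordinate values of the offending dimension.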
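As an in-Python alternative to ncpdq (not part of the original answer), xarray's `sortby` can put latitude into ascending order before stacking. The sketch below builds a small hypothetical grid with descending latitudes, like bad.nc, and checks whether the MultiIndex created by `stack` is sorted before and after `sortby`; this is one visible difference between the two files once they are stacked. The helper `stacked_is_monotonic` and the grid values are illustrative only.

```python
import numpy as np
import xarray as xr

# Hypothetical grid with latitudes stored north-to-south (descending), like bad.nc.
rng = np.random.default_rng(0)
da = xr.DataArray(
    rng.random((3, 2, 4)),
    coords={"lat": [49.0, 48.0, 47.0], "lon": [-124.0, -123.0], "time": np.arange(4)},
    dims=("lat", "lon", "time"),
)


def stacked_is_monotonic(arr):
    """Return True if the 'point' MultiIndex built from lat/lon is sorted ascending."""
    index = arr.stack(point=("lat", "lon")).indexes["point"]
    return index.is_monotonic_increasing


# Reorder latitudes south-to-north; the data values move with their coordinates.
da_sorted = da.sortby("lat")

print(stacked_is_monotonic(da), stacked_is_monotonic(da_sorted))
```

In the question's code, the equivalent (untested here) change would be to call `.sortby('lat')` on `da_precip` before `.stack(...)`.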

Answer 1 (score: 1)

Thanks for the answer.

Sorting the dimensions into ascending order does sometimes seem to fix the problem with xr.apply_ufunc. At other times, however, this maneuver is not enough.

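An alternative solution is to stack the coordinates over which the external user function will be broadcast into a new dimension (e.g., stack the "longitude" and "latitude" dimensions into a new dimension named "Grid_Point"). After stacking, group by this new "Grid_Point" dimension and apply xr.apply_ufunc to the groups.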

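Here is an example showing how to derive the statistical moments of a Gaussian distribution ("mean" and "standard deviation") for each pixel of a netCDF-based temperature dataset.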

import warnings

import numpy as np
import xarray as xr
# http://xarray.pydata.org/en/stable/dask.html
from dask.diagnostics import ProgressBar
from scipy import stats


def get_params_from_distribution(data, distribution='exponweib'):
    """Fit a scipy.stats distribution to one 1-D time series; return its parameters."""
    distribution = getattr(stats, distribution)

    if np.all(np.isnan(data)):
        # All-NaN pixel: fit a dummy sample only to learn how many parameters
        # the distribution returns, then return that many NaNs.
        with warnings.catch_warnings():
            warnings.filterwarnings(action="ignore")
            # Try 1, 2, then 3 positional arguments until rvs() accepts them.
            try:
                temp_data = distribution.rvs(1, size=10)
            except Exception:
                try:
                    temp_data = distribution.rvs(1, 1, size=10)
                except Exception:
                    temp_data = distribution.rvs(1, 1, 1, size=10)

            n_params = len(distribution.fit(temp_data))

        return data[:n_params]

    return np.asarray(distribution.fit(data))


def get_params_vectorized_from_stacked(stacked_data, distribution='exponweib',
                                       dask='allowed',
                                       input_core_dims='time',
                                       output_core_dims='stat_moments',
                                       output_dtypes=[float]):
    kwargs = {'distribution': distribution}

    with ProgressBar():
        # Each group is one time series: 'time' is consumed by the function
        # and replaced by a new 'stat_moments' dimension (one entry per parameter).
        params = xr.apply_ufunc(get_params_from_distribution,
                                stacked_data,
                                exclude_dims={input_core_dims},
                                kwargs=kwargs,
                                input_core_dims=[[input_core_dims]],
                                output_core_dims=[[output_core_dims]],
                                dask=dask,
                                output_dtypes=output_dtypes).compute()

    return params


def stack_ds(ds, dims=('lon', 'lat'), stacked_dim_name='point'):
    # Collapse the spatial dimensions into a single MultiIndex dimension.
    return ds.stack({stacked_dim_name: dims})


def main_pdf_u_function_getter(ds,
                               dims_to_stack=('lon', 'lat'),
                               stacked_dim_name='point',
                               distribution_name='exponweib',
                               dask='allowed',
                               input_core_dims='time',
                               output_core_dims='stat_moments',
                               output_dtypes=[float]):

    ds_stacked = stack_ds(ds, dims_to_stack, stacked_dim_name)  # observation 1

    ds_groupby = ds_stacked.groupby(stacked_dim_name)  # observation 1

    results = get_params_vectorized_from_stacked(ds_groupby,
                                                 distribution=distribution_name,
                                                 dask=dask,
                                                 output_core_dims=output_core_dims,
                                                 input_core_dims=input_core_dims,
                                                 output_dtypes=output_dtypes)

    return results.unstack(stacked_dim_name)


if __name__ == '__main__':
    import matplotlib.pyplot as plt

    # Sort every dimension into ascending order first (the tutorial dataset
    # stores its latitudes in descending order).
    ds = xr.tutorial.open_dataset('air_temperature').sortby(['lat', 'lon', 'time'])

    R = main_pdf_u_function_getter(ds,
                                   dask='parallelized',
                                   dims_to_stack=['lon', 'lat'],
                                   stacked_dim_name='point',
                                   distribution_name='norm')

    print(R)

    # One panel per fitted parameter (for 'norm': loc and scale).
    fig, ax = plt.subplots(1, 2)
    ax = ax.ravel()
    for moment in range(R.sizes['stat_moments']):
        R['air'].isel({'stat_moments': moment}).plot(ax=ax[moment], cmap='viridis')
    plt.show()

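Note that in the code above, two lines carry the comment "observation 1". These are the key lines that make the whole algorithm work: they stack the broadcast dimensions before the function is applied.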

Although this solution works, I still don't know why the stack is needed before xr.apply_ufunc. That remains an open question.
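For comparison, here is a sketch that is not from either answer: `apply_ufunc` can loop over pixels by itself when `time` is declared as an input core dimension and `vectorize=True` is set, which avoids both `stack` and `groupby`. `mean_std` and `fit_per_pixel` are hypothetical names; `mean_std` is a NumPy stand-in for `stats.norm.fit`, which returns the same sample mean and population standard deviation. The synthetic grid deliberately keeps descending latitudes, like bad.nc.

```python
import numpy as np
import xarray as xr


def mean_std(series):
    """Stand-in for stats.norm.fit: return [mean, std] of one 1-D time series."""
    return np.array([np.mean(series), np.std(series)])


def fit_per_pixel(da, func=mean_std, core_dim="time", new_dim="stat_moments"):
    # input_core_dims: func consumes `core_dim`; output_core_dims: it returns a
    # new `new_dim` axis. vectorize=True wraps func in np.vectorize, so it is
    # called once per (lat, lon) pixel with a 1-D array.
    return xr.apply_ufunc(
        func,
        da,
        input_core_dims=[[core_dim]],
        output_core_dims=[[new_dim]],
        vectorize=True,
    )


# Hypothetical grid with descending latitudes; values are random placeholders.
rng = np.random.default_rng(42)
da = xr.DataArray(
    rng.normal(size=(3, 2, 12)),
    coords={"lat": [49.0, 48.0, 47.0], "lon": [-124.0, -123.0], "time": np.arange(12)},
    dims=("lat", "lon", "time"),
)

params = fit_per_pixel(da)
print(params.dims)
```

`np.vectorize` is a Python-level loop, so this trades speed for simplicity; the coordinate order of `lat` plays no role because no grouping takes place.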

Regards