I want to downsample an xarray Dataset based on particular groups, so I'm using groupby to select the groups and then taking 10% of the samples within each group. I'm using the code below, but I get IndexError: index 1330 is out of bounds for axis 0 with size 1330, which tells me my function is returning an empty array, yet subset definitely has non-zero dimensions.
I was using squeeze=True, which I thought would allow new sizes according to the GroupBy documentation, but that didn't help, so I changed it to squeeze=False.
Any idea what might be going on? Thanks!
# Set random seed for reproducibility
np.random.seed(0)

def select_random_cell_subset(x):
    size = int(0.1 * len(x.cell))
    random_cells = sorted(np.random.choice(x.cell, size=size, replace=False))
    print('number of random cells:', len(random_cells))
    print('\tsome random cells:', random_cells[:5])
    subset = x.sel(cell=random_cells)
    print('subset:', subset)
    return subset

# squeeze=False because the final dataset is smaller than the original
ds_subset = ds.groupby('group', squeeze=True).apply(select_random_cell_subset)
ds_subset
Here is the error:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-44-39c7803e9e40> in <module>()
12
13 # squeeze=False because the final dataset is smaller than the original
---> 14 ds_subset = ds.groupby('group', squeeze=True).apply(select_random_cell_subset)
15 ds_subset
~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in apply(self, func, **kwargs)
615 kwargs.pop('shortcut', None) # ignore shortcut if set (for now)
616 applied = (func(ds, **kwargs) for ds in self._iter_grouped())
--> 617 return self._combine(applied)
618
619 def _combine(self, applied):
~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in _combine(self, applied)
622 coord, dim, positions = self._infer_concat_args(applied_example)
623 combined = concat(applied, dim)
--> 624 combined = _maybe_reorder(combined, dim, positions)
625 if coord is not None:
626 combined[coord.name] = coord
~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/groupby.py in _maybe_reorder(xarray_obj, dim, positions)
443 return xarray_obj
444 else:
--> 445 return xarray_obj[{dim: order}]
446
447
~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/dataset.py in __getitem__(self, key)
716 """
717 if utils.is_dict_like(key):
--> 718 return self.isel(**key)
719
720 if hashable(key):
~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/dataset.py in isel(self, drop, **indexers)
1141 for name, var in iteritems(self._variables):
1142 var_indexers = dict((k, v) for k, v in indexers if k in var.dims)
-> 1143 new_var = var.isel(**var_indexers)
1144 if not (drop and name in var_indexers):
1145 variables[name] = new_var
~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/variable.py in isel(self, **indexers)
568 if dim in indexers:
569 key[i] = indexers[dim]
--> 570 return self[tuple(key)]
571
572 def squeeze(self, dim=None):
~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/variable.py in __getitem__(self, key)
398 dims = tuple(dim for k, dim in zip(key, self.dims)
399 if not isinstance(k, integer_types))
--> 400 values = self._indexable_data[key]
401 # orthogonal indexing should ensure the dimensionality is consistent
402 if hasattr(values, 'ndim'):
~/anaconda3/envs/cshl-sca-2017/lib/python3.6/site-packages/xarray/core/indexing.py in __getitem__(self, key)
476 def __getitem__(self, key):
477 key = self._convert_key(key)
--> 478 return self._ensure_ndarray(self.array[key])
479
480 def __setitem__(self, key, value):
IndexError: index 1330 is out of bounds for axis 0 with size 1330
Answer 0 (score: 4)
This is a very sensible thing to want to do, but unfortunately it doesn't work yet. Xarray uses some heuristics to determine whether an apply operation is of the reduce or transform type, and in this case we mistakenly identify the grouped operation as a "transform" because the output reuses the original dimension name. I just filed a bug report, but unfortunately a fix in xarray won't land soon enough to help here.
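To illustrate the distinction, here is a minimal sketch with a hypothetical toy dataset (the variable names and values below are invented for illustration and are not from the question):

import numpy as np
import xarray as xr

# Hypothetical toy dataset: 6 cells split into 2 groups
toy = xr.Dataset(
    {'expr': ('cell', np.arange(6.0))},
    coords={'cell': list('abcdef'),
            'group': ('cell', [0, 0, 0, 1, 1, 1])})

# "Reduce"-type apply: each group's result drops the 'cell' dimension,
# so xarray concatenates the per-group results along a new 'group' dimension.
group_means = toy.groupby('group').apply(lambda g: g.mean('cell'))

# "Transform"-type apply: each group's result keeps the 'cell' dimension.
# A subsetting function also keeps 'cell' (just with fewer elements), so the
# heuristic treats it as a transform and tries to reorder the combined result
# back to the original cell positions, which is what raises the IndexError above.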
The easiest workaround is probably to have the applied function return a boolean DataArray indicating which positions to keep. Then you can select from the original object with an indexing operation.
Answer 1 (score: 3)
Here is how I implemented it. As @shoyer suggested above, I returned a boolean xarray.DataArray for each group, and then used that boolean array to subset my data.
# Set random seed for reproducibility
np.random.seed(0)

def select_random_cell_subset(x, threshold=0.1):
    random_bools = xr.DataArray(np.random.uniform(size=len(x.cell)) <= threshold,
                                coords=dict(cell=x.cell))
    return random_bools

subset_bools = ds.groupby('group').apply(select_random_cell_subset,
                                         threshold=0.1)
ds_subset = ds.sel(cell=subset_bools)
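As a quick sanity check, here is a sketch (reusing the ds and ds_subset names from the snippet above) that compares the number of cells before and after; the exact count varies with the random draw:

# Roughly 10% of the cells should survive the subsampling
print('original cells:', ds.dims['cell'])
print('subset cells:  ', ds_subset.dims['cell'])
print('fraction kept: ', ds_subset.dims['cell'] / ds.dims['cell'])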