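I'm having trouble figuring out how to use xarray DataArrays and Datasets and how to perform algebraic operations on them, particularly when the dimensions have different levels and my cubes have different granularity. If anyone could point me to some documentation or give me some advice, I would appreciate it.
In the example below, I am trying to calculate the contribution of each child (SKU) under its parent (PFS). I found that to get the right values I need to convert the slice of the cube to a pandas DataFrame. Otherwise, xarray duplicates the dimensions I am working with.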
import pandas as pd
import numpy as np
import xarray as xr
from itertools import product
# Create hierarchies
usage_type_entities = (('Regular',), ('Sample',),
('Tender',), ('Clinic Trial',))
usage_type_tree = pd.MultiIndex.from_tuples(
usage_type_entities, names=('Usage_Type',))
product_tree_hierarchy = (("PF1", "PFS1", "SKU1"),
("PF1", "PFS1", "SKU2"),
("PF1", "PFS2", "SKU3"),
("PF1", "PFS2", "SKU4"),
("PF2", "PFS3", "SKU5"))
product_tree_entities = ("PF", "PFS", "SKU")
product_tree = pd.MultiIndex.from_tuples(product_tree_hierarchy,
names=product_tree_entities)
market_tree_hierarchy = (("Group1", "Region1", "Market1"),
("Group1", "Region1", "Market2"),
("Group1", "Region2", "Market3"),
("Group1", "Region2", "Market4"),
("Group2", "Region3", "Market5"))
market_tree_entities = ("Groups", "Regions", "Markets")
market_tree = pd.MultiIndex.from_tuples(market_tree_hierarchy,
names=market_tree_entities)
# (Year, Quarter) pairs such as ("2013", "2013Q1"); keep the first 22 quarters
time_tree_hierarchy = [(str(y), f"{y}Q{q}")
                       for y, q in product(range(2013, 2019), range(1, 5))][:22]
time_entities = ("Year", "Quarter")
time_tree = pd.MultiIndex.from_tuples(time_tree_hierarchy, names=time_entities)
# Create the xarray DataArray
x1 = np.random.randint(100, size=(len(usage_type_tree), len(
product_tree), len(market_tree), len(time_tree)))
xda = xr.DataArray(x1, coords=(usage_type_tree, product_tree, market_tree, time_tree),
dims=("Usage", "Product", "Market", "Time"))
# Operations - I need to convert my slice into a pandas DataFrame to get
# the right values. Converting to a pandas DataFrame works fine.
market = "Market1"
ut = "Regular"
(xda.sel(Markets=market, Usage_Type=ut)[:, 0].to_pandas() /
xda.sel(Markets=market, Usage_Type=ut)[:, 0].to_pandas().groupby("PFS").sum(axis=0))
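If I don't convert the slice to a DataFrame and instead keep it as an xarray object, the dimensions get duplicated. For example, the lines below produce a DataArray with dimensions (Product: 5, Time: 22, PFS: 3), when it should be just (Product: 5, Time: 22).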
(xda.sel(Markets=market, Usage_Type=ut)[:, 0] /
xda.sel(Markets=market, Usage_Type=ut)[:, 0].groupby("PFS").sum(axis=0))
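A minimal sketch of what may be going on, using toy data rather than the cube above. It assumes the extra dimension comes from xarray broadcasting operands by dimension name: the grouped sum has a `PFS` dimension that the slice lacks, so plain division produces the outer combination of both. Dividing the `GroupBy` object by the totals is one possible way to keep the original shape. The array `da`, its values, and the use of a plain (non-MultiIndex) `PFS` coordinate are assumptions made for illustration; behaviour with MultiIndex levels may differ between xarray versions.

```python
import numpy as np
import xarray as xr

# Hypothetical toy data: 5 SKUs over 3 periods. The parent PFS of each SKU is
# stored as an ordinary (non-index) coordinate along the "Product" dimension.
da = xr.DataArray(
    np.arange(1.0, 16.0).reshape(5, 3),
    dims=("Product", "Time"),
    coords={
        "Product": ["SKU1", "SKU2", "SKU3", "SKU4", "SKU5"],
        "PFS": ("Product", ["PFS1", "PFS1", "PFS2", "PFS2", "PFS3"]),
        "Time": ["2013Q1", "2013Q2", "2013Q3"],
    },
)

# Total per parent: the "Product" dimension is replaced by a "PFS" dimension.
totals = da.groupby("PFS").sum("Product")

# Plain division broadcasts by dimension name, so the result carries
# "Product", "Time" and an extra "PFS" dimension.
naive = da / totals

# Grouped arithmetic matches each SKU with the total of its own PFS only,
# so the result keeps just the "Product" and "Time" dimensions.
share = da.groupby("PFS") / totals

print(naive.dims)
print(share.dims)
```

In this sketch `share` holds each SKU's contribution to its parent, so the values within each PFS sum to 1 for every period.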