请尝试提供参数化解决方案(有三个以上的替代方案)。
我有一个包含Beta值的字典:
{'B_X1': 2.0, 'B_X2': -3.0}
此数据框:
X1_123 X1_456 X1_789 X2_123 X2_456 X2_789
6.75 4.69 9.59 5.52 9.69 7.40
7.46 4.94 3.01 1.78 1.38 4.68
2.05 7.30 4.08 7.02 8.24 8.49
5.60 7.88 8.11 5.98 4.60 1.39
1.80 8.28 9.16 7.34 7.69 6.16
3.73 6.93 8.93 2.58 3.48 6.04
8.06 8.88 7.06 6.76 4.68 7.82
5.00 7.29 5.86 3.92 5.67 4.10
2.49 2.55 4.66 7.15 6.26 7.87
1.50 3.35 5.70 9.86 4.83 1.17
8.19 7.72 9.56 6.61 4.15 3.64
2.43 9.54 9.15 4.41 9.18 7.85
2.71 3.24 4.56 6.22 7.89 9.93
5.96 4.34 5.26 8.63 9.81 9.40
123
,456
和789
是替代方案。
j
,k
和s
是上述选择。
预期结果:
X1_123 X1_456 X1_789 X2_123 X2_456 X2_789 P_123 P_456 P_789
6.75 4.69 9.59 5.52 9.69 7.40 0.490 0.000 0.510
7.46 4.94 3.01 1.78 1.38 4.68 0.979 0.021 0.000
2.05 7.30 4.08 7.02 8.24 8.49 0.001 0.998 0.001
5.60 7.88 8.11 5.98 4.60 1.39 0.000 0.000 1.000
1.80 8.28 9.16 7.34 7.69 6.16 0.000 0.002 0.998
3.73 6.93 8.93 2.58 3.48 6.04 0.024 0.952 0.024
8.06 8.88 7.06 6.76 4.68 7.82 0.000 1.000 0.000
5.00 7.29 5.86 3.92 5.67 4.10 0.210 0.107 0.683
2.49 2.55 4.66 7.15 6.26 7.87 0.038 0.623 0.339
1.50 3.35 5.70 9.86 4.83 1.17 0.000 0.000 1.000
8.19 7.72 9.56 6.61 4.15 3.64 0.000 0.005 0.995
2.43 9.54 9.15 4.41 9.18 7.85 0.041 0.037 0.922
2.71 3.24 4.56 6.22 7.89 9.93 0.981 0.019 0.001
5.96 4.34 5.26 8.63 9.81 9.40 0.975 0.001 0.024
每行的概率总和应为1。
请尝试提供参数化解决方案(有三个以上的替代方案)。
每个替代方案具有恒定的预期结果:
{'B_X1': 2.0, 'B_X2': -3.0, 'B_123': 0.1, 'B_456': 0.2, 'B_789': 0.3}
X1_123 X1_456 X1_789 X2_123 X2_456 X2_789 P_123 P_456 P_789
6.75 4.69 9.59 5.52 9.69 7.40 0.440 0.000 0.560
7.46 4.94 3.01 1.78 1.38 4.68 0.977 0.023 0.000
2.05 7.30 4.08 7.02 8.24 8.49 0.001 0.998 0.001
5.60 7.88 8.11 5.98 4.60 1.39 0.000 0.000 1.000
1.80 8.28 9.16 7.34 7.69 6.16 0.000 0.002 0.998
3.73 6.93 8.93 2.58 3.48 6.04 0.021 0.952 0.027
8.06 8.88 7.06 6.76 4.68 7.82 0.000 1.000 0.000
5.00 7.29 5.86 3.92 5.67 4.10 0.180 0.102 0.717
2.49 2.55 4.66 7.15 6.26 7.87 0.034 0.604 0.363
1.50 3.35 5.70 9.86 4.83 1.17 0.000 0.000 1.000
8.19 7.72 9.56 6.61 4.15 3.64 0.000 0.005 0.995
2.43 9.54 9.15 4.41 9.18 7.85 0.034 0.034 0.932
2.71 3.24 4.56 6.22 7.89 9.93 0.978 0.021 0.001
5.96 4.34 5.26 8.63 9.81 9.40 0.970 0.001 0.029
答案 0 :(得分:2)
IIUC:
将列转换为MultiIndex
df = df.set_axis(df.columns.str.split('_', expand=True), axis=1, inplace=False)
并定义您的B
,以使键与df
中的前缀匹配
B = {'X1': 2.0, 'X2': -3.0}
然后
def f(b, x):
return np.exp((b * x).sum(1))
parts = f(B, df.stack()).unstack()
preds = parts.div(parts.sum(1), axis=0)
df.join(pd.concat({'P': preds}, axis=1).round(3)).pipe(
lambda d: d.set_axis(map('_'.join, d.columns), axis=1, inplace=False)
)
X1_123 X1_456 X1_789 X2_123 X2_456 X2_789 P_123 P_456 P_789
0 6.75 4.69 9.59 5.52 9.69 7.40 0.490 0.000 0.510
1 7.46 4.94 3.01 1.78 1.38 4.68 0.979 0.021 0.000
2 2.05 7.30 4.08 7.02 8.24 8.49 0.001 0.998 0.001
3 5.60 7.88 8.11 5.98 4.60 1.39 0.000 0.000 1.000
4 1.80 8.28 9.16 7.34 7.69 6.16 0.000 0.002 0.998
5 3.73 6.93 8.93 2.58 3.48 6.04 0.024 0.952 0.024
6 8.06 8.88 7.06 6.76 4.68 7.82 0.000 1.000 0.000
7 5.00 7.29 5.86 3.92 5.67 4.10 0.210 0.107 0.683
8 2.49 2.55 4.66 7.15 6.26 7.87 0.038 0.623 0.339
9 1.50 3.35 5.70 9.86 4.83 1.17 0.000 0.000 1.000
10 8.19 7.72 9.56 6.61 4.15 3.64 0.000 0.005 0.995
11 2.43 9.54 9.15 4.41 9.18 7.85 0.041 0.037 0.922
12 2.71 3.24 4.56 6.22 7.89 9.93 0.981 0.019 0.001
13 5.96 4.34 5.26 8.63 9.81 9.40 0.975 0.001 0.024
包装成一个漂亮的功能
def f(df, b):
d = df.set_axis(df.columns.str.split('_', expand=True), axis=1, inplace=False)
parts = np.exp(d.stack().mul(b).sum(1).unstack())
preds = pd.concat({'P': parts.div(parts.sum(1), axis=0)}, axis=1).round(3)
d = d.join(preds)
d.columns = list(map('_'.join, d.columns))
return d
f(df, B)
X1_123 X1_456 X1_789 X2_123 X2_456 X2_789 P_123 P_456 P_789
0 6.75 4.69 9.59 5.52 9.69 7.40 0.490 0.000 0.510
1 7.46 4.94 3.01 1.78 1.38 4.68 0.979 0.021 0.000
2 2.05 7.30 4.08 7.02 8.24 8.49 0.001 0.998 0.001
3 5.60 7.88 8.11 5.98 4.60 1.39 0.000 0.000 1.000
4 1.80 8.28 9.16 7.34 7.69 6.16 0.000 0.002 0.998
5 3.73 6.93 8.93 2.58 3.48 6.04 0.024 0.952 0.024
6 8.06 8.88 7.06 6.76 4.68 7.82 0.000 1.000 0.000
7 5.00 7.29 5.86 3.92 5.67 4.10 0.210 0.107 0.683
8 2.49 2.55 4.66 7.15 6.26 7.87 0.038 0.623 0.339
9 1.50 3.35 5.70 9.86 4.83 1.17 0.000 0.000 1.000
10 8.19 7.72 9.56 6.61 4.15 3.64 0.000 0.005 0.995
11 2.43 9.54 9.15 4.41 9.18 7.85 0.041 0.037 0.922
12 2.71 3.24 4.56 6.22 7.89 9.93 0.981 0.019 0.001
13 5.96 4.34 5.26 8.63 9.81 9.40 0.975 0.001 0.024