说我需要max_(a', m') f(a, m, e, m', a')
,并且我已使用网格f
近似V1
。这是一个形状为(nA, nM, nE, nM, nA)
的numpy矩阵(最后附上)。
我想首先进行插值,然后进行最大化。以下是我当前的代码(我粘贴代码以在最后重新创建Grid
):
# takes grid indices (first three dimensions) idx and interpolates on V
def interpolateV(idx, V, Grid):
from scipy.interpolate import interp2d
f = interp2d(Grid.mGrid, Grid.aGrid, V[idx])
return f
# (somewhere else:)
s2 = (Grid.nM, Grid.nA, Grid.nE)
v1Max = np.empty(s2)
v1ArgMaxA = np.empty(s2)
v1ArgMaxM = np.empty(s2)
from scipy import optimize
for idx in np.ndindex(V1[..., 0,0].shape):
V1i = interpolateV(idx, V1, Grid)
x, f, d = optimize.fmin_l_bfgs_b(lambda x: -V1i(x[0], x[1]), np.array([1, 1]), bounds=[(Grid.aMin, Grid.aMax), (Grid.mMin, Grid.mMax)], approx_grad=True)
v1Max[idx] = f
v1ArgMaxA[idx], v1ArgMaxM[idx] = x
# let's compare with standard grid-wise optimization (without interpolation):
temp = V1.max(axis=-1)
# maximize over m
v1Max = temp.max(axis=-1)
# now max over a, given optimal m
v1ArgMaxAGrid = temp.argmax(axis=-1)
到目前为止,这么好。但是,插值最大化的值是偏离的:
In[51]: v1ArgMaxAGrid[:,:,0]
Out[51]:
array([[0, 0, 0, 0, 2],
[0, 0, 0, 0, 2],
[0, 0, 0, 2, 2],
[0, 0, 0, 2, 3],
[0, 0, 0, 2, 3]], dtype=int64)
In[54]: Grid.aGrid[v1ArgMaxAGrid[:,:,0]]
Out[54]:
array([[ 0. , 0. , 0. , 0. , 3.5 ],
[ 0. , 0. , 0. , 0. , 3.5 ],
[ 0. , 0. , 0. , 3.5 , 3.5 ],
[ 0. , 0. , 0. , 3.5 , 5.25],
[ 0. , 0. , 0. , 3.5 , 5.25]])
In[52]: v1ArgMaxA[:,:,0]
Out[52]:
array([[ 0. , 0.75 , 2.25 , 7. , 7. ],
[ 0. , 1.5 , 4.247, 7. , 7. ],
[ 0.75 , 2.25 , 7. , 7. , 7. ],
[ 1.5 , 1.5 , 7. , 7. , 7. ],
[ 2.25 , 4.939, 7. , 7. , 7. ]])
这里发生了什么;为什么价值如此偏离?我做错了吗?
在复制粘贴后重新创建Grid
,V1
:
class Grids(object):
nE = 2
nA = 5
nM = 5
M = 3
A = 7
mMin = 0
mMax = M
aMin = 0
aMax = A
def __init__(self):
self.reset();
def reset(self):
self.mGrid = np.linspace(self.mMin, self.mMax, self.nM)
self.aGrid = np.linspace(self.aMin, self.aMax, self.nA)
self.eGrid = np.array([0.318, 3.149])
self.transitionE = np.array([[1., 0.],
[0., 1.]])
import numpy as np
Grid = Grids()
V1 = np.array([[[[[ 1.19 , 0.975, -0.371, -2.848, -6.456],
[ -1.463, -4.313, -8.294, -13.407, -19.65 ],
[ -9.888, -15.377, -21.997, -29.748, -38.63 ],
[-24.574, -32.701, -41.958, -52.347, -63.866],
[-45.562, -56.325, -68.218, -81.242, -95.397]],
[[ 64.724, 64.672, 64.567, 64.358, 54.127],
[ 64.247, 63.964, 53.759, 52.687, 50.487],
[ 53.526, 52.078, 49.501, 45.799, 40.969],
[ 48.389, 44.307, 39.105, 32.769, 25.314],
[ 37.062, 30.347, 22.52 , 13.553, 3.47 ]]],
[[[ 12.624, 12.704, 12.591, 2.602, 1.618],
[ 2.237, 2.011, 0.655, -1.832, -5.45 ],
[ -0.064, -2.928, -6.923, -12.049, -18.306],
[ -8.624, -14.126, -20.759, -28.522, -37.416],
[-23.488, -31.625, -40.894, -51.293, -62.822]],
[[ 65.686, 65.695, 65.679, 65.631, 65.537],
[ 65.401, 65.342, 65.23 , 65.014, 54.778],
[ 65.174, 64.881, 54.667, 53.59 , 51.385],
[ 54.43 , 52.973, 50.396, 46.685, 41.855],
[ 49.228, 45.138, 39.936, 33.594, 26.136]]],
[[[ 13.681, 13.872, 14.024, 14.117, 14.093],
[ 13.671, 13.74 , 13.617, 3.617, 2.624],
[ 3.636, 3.397, 2.027, -0.474, -4.106],
[ 1.2 , -1.677, -5.684, -10.823, -17.092],
[ -7.538, -13.051, -19.694, -27.468, -36.373]],
[[ 66.553, 66.597, 66.623, 66.631, 66.614],
[ 66.362, 66.364, 66.342, 66.287, 66.188],
[ 66.327, 66.259, 66.138, 65.917, 55.676],
[ 66.077, 65.776, 55.562, 54.476, 52.271],
[ 55.269, 53.804, 51.227, 47.51 , 42.677]]],
[[[ 14.6 , 14.839, 15.054, 15.242, 15.394],
[ 14.728, 14.909, 15.05 , 15.133, 15.098],
[ 15.07 , 15.126, 14.989, 4.975, 3.968],
[ 4.9 , 4.648, 3.265, 0.752, -2.892],
[ 2.286, -0.601, -4.62 , -9.769, -16.048]],
[[ 67.36 , 67.427, 67.481, 67.52 , 67.543],
[ 67.229, 67.266, 67.286, 67.287, 67.265],
[ 67.288, 67.281, 67.25 , 67.19 , 67.085],
[ 67.231, 67.154, 67.033, 66.803, 56.562],
[ 66.917, 66.607, 56.393, 55.301, 53.093]]],
[[[ 15.442, 15.71 , 15.96 , 16.191, 16.4 ],
[ 15.647, 15.875, 16.08 , 16.257, 16.399],
[ 16.128, 16.294, 16.422, 16.491, 16.443],
[ 16.334, 16.377, 16.227, 6.201, 5.182],
[ 5.986, 5.723, 4.33 , 1.806, -1.849]],
[[ 68.123, 68.207, 68.28 , 68.342, 68.391],
[ 68.036, 68.096, 68.143, 68.176, 68.194],
[ 68.155, 68.183, 68.195, 68.19 , 68.163],
[ 68.192, 68.176, 68.145, 68.076, 67.971],
[ 68.07 , 67.984, 67.864, 67.628, 57.384]]]],
[[[[ 11.877, 1.81 , 1.59 , 0.238, -2.246],
[ 0.873, -0.853, -3.709, -7.696, -12.814],
[ -4.928, -9.292, -14.787, -21.413, -29.17 ],
[-16.988, -23.99 , -32.123, -41.386, -51.78 ],
[-35.352, -44.989, -55.758, -67.657, -80.686]],
[[ 65.151, 65.131, 65.075, 64.966, 64.753],
[ 64.779, 64.647, 64.36 , 54.151, 53.076],
[ 54.24 , 53.917, 52.465, 49.888, 46.183],
[ 51.728, 48.771, 44.694, 39.483, 33.153],
[ 43.026, 37.436, 30.734, 22.892, 13.934]]],
[[[ 13.101, 13.245, 13.318, 13.2 , 3.204],
[ 12.924, 2.847, 2.616, 1.253, -1.24 ],
[ 2.272, 0.533, -2.337, -6.338, -11.47 ],
[ -3.664, -8.041, -13.548, -20.187, -27.956],
[-15.902, -22.915, -31.058, -40.332, -50.737]],
[[ 66.066, 66.092, 66.098, 66.078, 66.025],
[ 65.827, 65.8 , 65.738, 65.622, 65.403],
[ 65.706, 65.564, 65.268, 55.054, 53.974],
[ 55.144, 54.812, 53.36 , 50.774, 47.069],
[ 52.567, 49.602, 45.525, 40.308, 33.975]]],
[[[ 14.087, 14.302, 14.487, 14.632, 14.72 ],
[ 14.148, 14.281, 14.344, 14.216, 4.21 ],
[ 14.323, 4.232, 3.987, 2.611, 0.104],
[ 3.536, 1.784, -1.099, -5.112, -10.256],
[ -2.578, -6.965, -12.484, -19.133, -26.912]],
[[ 66.905, 66.96 , 66.999, 67.022, 67.025],
[ 66.742, 66.762, 66.76 , 66.734, 66.676],
[ 66.754, 66.718, 66.646, 66.525, 66.301],
[ 66.609, 66.459, 66.163, 55.94 , 54.86 ],
[ 55.983, 55.643, 54.191, 51.599, 47.891]]],
[[[ 14.969, 15.221, 15.454, 15.663, 15.844],
[ 15.134, 15.338, 15.513, 15.648, 15.725],
[ 15.548, 15.667, 15.716, 15.574, 5.554],
[ 15.587, 5.483, 5.226, 3.837, 1.318],
[ 4.622, 2.859, -0.034, -4.058, -9.213]],
[[ 67.691, 67.766, 67.829, 67.879, 67.915],
[ 67.581, 67.629, 67.662, 67.678, 67.676],
[ 67.669, 67.679, 67.669, 67.637, 67.574],
[ 67.658, 67.612, 67.541, 67.411, 67.187],
[ 67.449, 67.29 , 66.994, 56.765, 55.682]]],
[[[ 15.786, 16.063, 16.324, 16.569, 16.794],
[ 16.016, 16.258, 16.48 , 16.679, 16.85 ],
[ 16.533, 16.724, 16.884, 17.006, 17.07 ],
[ 16.811, 16.918, 16.954, 16.8 , 6.768],
[ 16.673, 6.559, 6.29 , 4.891, 2.362]],
[[ 68.439, 68.529, 68.61 , 68.679, 68.737],
[ 68.368, 68.436, 68.492, 68.535, 68.566],
[ 68.507, 68.546, 68.571, 68.581, 68.574],
[ 68.572, 68.573, 68.563, 68.523, 68.46 ],
[ 68.497, 68.443, 68.372, 68.236, 68.009]]]],
[[[[ 12.453, 12.498, 2.425, 2.198, 0.84 ],
[ 2.083, 1.483, -0.248, -3.111, -7.104],
[ -1.092, -4.331, -8.701, -14.202, -20.834],
[-10.528, -16.405, -23.412, -31.551, -40.82 ],
[-26.266, -34.779, -44.422, -55.196, -67.101]],
[[ 65.555, 65.558, 65.534, 65.474, 65.36 ],
[ 65.252, 65.179, 65.043, 64.752, 54.54 ],
[ 64.974, 54.631, 54.304, 52.852, 50.272],
[ 53.942, 52.11 , 49.158, 45.072, 39.867],
[ 47.865, 43.4 , 37.823, 31.106, 23.273]]],
[[[ 13.541, 13.722, 13.859, 13.927, 13.802],
[ 13.5 , 13.534, 3.451, 3.214, 1.846],
[ 3.482, 2.868, 1.123, -1.753, -5.76 ],
[ 0.172, -3.08 , -7.463, -12.976, -19.62 ],
[ -9.442, -15.329, -22.348, -30.497, -39.776]],
[[ 66.433, 66.473, 66.495, 66.496, 66.472],
[ 66.231, 66.227, 66.196, 66.13 , 66.011],
[ 66.178, 66.096, 65.952, 65.655, 55.438],
[ 65.877, 55.526, 55.199, 53.738, 51.158],
[ 54.781, 52.941, 49.989, 45.897, 40.689]]],
[[[ 14.475, 14.708, 14.917, 15.095, 15.235],
[ 14.588, 14.759, 14.886, 14.943, 14.808],
[ 14.899, 14.92 , 4.823, 4.572, 3.19 ],
[ 4.746, 4.119, 2.362, -0.527, -4.546],
[ 1.258, -2.005, -6.398, -11.922, -18.577]],
[[ 67.247, 67.312, 67.362, 67.398, 67.417],
[ 67.109, 67.142, 67.158, 67.152, 67.123],
[ 67.158, 67.145, 67.105, 67.033, 66.909],
[ 67.082, 66.991, 66.847, 66.541, 56.324],
[ 66.717, 56.357, 56.03 , 54.563, 51.98 ]]],
[[[ 15.325, 15.59 , 15.836, 16.062, 16.265],
[ 15.522, 15.744, 15.943, 16.111, 16.24 ],
[ 15.987, 16.144, 16.257, 16.301, 16.152],
[ 16.163, 16.171, 6.061, 5.798, 4.404],
[ 5.832, 5.195, 3.426, 0.527, -3.502]],
[[ 68.016, 68.098, 68.169, 68.228, 68.274],
[ 67.924, 67.981, 68.025, 68.054, 68.067],
[ 68.036, 68.059, 68.066, 68.056, 68.021],
[ 68.061, 68.039, 68. , 67.919, 67.794],
[ 67.921, 67.822, 67.677, 67.366, 57.146]]],
[[[ 16.121, 16.406, 16.678, 16.933, 17.171],
[ 16.372, 16.626, 16.862, 17.078, 17.271],
[ 16.921, 17.13 , 17.314, 17.469, 17.585],
[ 17.251, 17.395, 17.496, 17.527, 17.366],
[ 17.249, 17.246, 7.126, 6.852, 5.447]],
[[ 68.749, 68.846, 68.932, 69.008, 69.074],
[ 68.692, 68.768, 68.832, 68.884, 68.925],
[ 68.85 , 68.898, 68.934, 68.957, 68.965],
[ 68.939, 68.954, 68.961, 68.941, 68.906],
[ 68.901, 68.87 , 68.831, 68.744, 68.617]]]],
[[[[ 12.947, 13.074, 13.112, 3.034, 2.8 ],
[ 12.693, 2.693, 2.087, 0.35 , -2.518],
[ 1.618, -0.496, -3.741, -8.117, -13.624],
[ -5.192, -9.944, -15.827, -22.84 , -30.984],
[-18.306, -25.693, -34.212, -43.861, -54.64 ]],
[[ 65.941, 65.962, 65.961, 65.932, 65.868],
[ 65.688, 65.652, 65.575, 65.435, 65.141],
[ 65.537, 65.364, 55.018, 54.691, 53.236],
[ 55.031, 54.324, 52.497, 49.536, 45.456],
[ 51.579, 48.239, 43.787, 38.195, 31.487]]],
[[[ 13.954, 14.162, 14.337, 14.468, 14.529],
[ 13.994, 14.11 , 14.139, 4.049, 3.806],
[ 14.093, 4.079, 3.459, 1.708, -1.174],
[ 2.882, 0.755, -2.502, -6.891, -12.41 ],
[ -4.106, -8.869, -14.762, -21.786, -29.941]],
[[ 66.789, 66.84 , 66.876, 66.893, 66.891],
[ 66.617, 66.631, 66.623, 66.588, 66.519],
[ 66.614, 66.569, 66.484, 66.339, 66.039],
[ 66.441, 66.259, 55.913, 55.577, 54.122],
[ 55.87 , 55.155, 53.328, 50.361, 46.278]]],
[[[ 14.847, 15.096, 15.323, 15.525, 15.698],
[ 15.001, 15.198, 15.363, 15.484, 15.535],
[ 15.394, 15.496, 15.51 , 5.407, 5.15 ],
[ 15.356, 5.33 , 4.697, 2.934, 0.04 ],
[ 3.968, 1.831, -1.438, -5.837, -11.366]],
[[ 67.582, 67.654, 67.714, 67.761, 67.792],
[ 67.465, 67.509, 67.538, 67.549, 67.542],
[ 67.543, 67.548, 67.532, 67.492, 67.417],
[ 67.518, 67.463, 67.379, 67.224, 66.925],
[ 67.28 , 67.09 , 56.744, 56.402, 54.944]]],
[[[ 15.672, 15.946, 16.204, 16.444, 16.665],
[ 15.894, 16.132, 16.349, 16.541, 16.703],
[ 16.4 , 16.584, 16.734, 16.842, 16.879],
[ 16.657, 16.746, 16.749, 6.633, 6.364],
[ 16.443, 6.405, 5.762, 3.988, 1.083]],
[[ 68.334, 68.423, 68.501, 68.568, 68.623],
[ 68.258, 68.324, 68.377, 68.417, 68.443],
[ 68.391, 68.426, 68.447, 68.453, 68.439],
[ 68.447, 68.443, 68.427, 68.378, 68.302],
[ 68.357, 68.294, 68.209, 68.049, 67.747]]],
[[[ 16.448, 16.742, 17.021, 17.286, 17.535],
[ 16.719, 16.983, 17.23 , 17.46 , 17.67 ],
[ 17.294, 17.517, 17.72 , 17.899, 18.048],
[ 17.664, 17.835, 17.973, 18.068, 18.093],
[ 17.744, 17.822, 17.813, 7.687, 7.408]],
[[ 69.055, 69.156, 69.248, 69.331, 69.403],
[ 69.01 , 69.092, 69.163, 69.224, 69.273],
[ 69.184, 69.241, 69.286, 69.32 , 69.341],
[ 69.295, 69.321, 69.341, 69.339, 69.325],
[ 69.286, 69.274, 69.257, 69.203, 69.125]]]],
[[[[ 13.398, 13.568, 13.688, 13.721, 3.636],
[ 13.32 , 13.303, 3.298, 2.685, 0.942],
[ 3.204, 2.215, 0.095, -3.156, -7.538],
[ -0.982, -4.609, -9.366, -15.255, -22.274],
[-11.47 , -17.733, -25.126, -33.65 , -43.305]],
[[ 66.312, 66.348, 66.364, 66.359, 66.327],
[ 66.099, 66.088, 66.048, 65.967, 65.825],
[ 66.025, 65.928, 65.752, 55.405, 55.075],
[ 65.656, 55.413, 54.711, 52.875, 49.92 ],
[ 54.168, 51.953, 48.626, 44.159, 38.576]]],
[[[ 14.347, 14.575, 14.776, 14.945, 15.07 ],
[ 14.445, 14.605, 14.714, 14.737, 4.642],
[ 14.72 , 14.689, 4.669, 4.043, 2.286],
[ 4.468, 3.466, 1.333, -1.93 , -6.324],
[ 0.104, -3.533, -8.302, -14.201, -21.23 ]],
[[ 67.134, 67.195, 67.243, 67.274, 67.288],
[ 66.988, 67.017, 67.027, 67.015, 66.978],
[ 67.025, 67.005, 66.956, 66.871, 66.722],
[ 66.929, 66.822, 66.646, 56.291, 55.961],
[ 66.496, 56.244, 55.542, 53.7 , 50.742]]],
[[[ 15.208, 15.468, 15.71 , 15.931, 16.128],
[ 15.394, 15.611, 15.802, 15.961, 16.076],
[ 15.844, 15.99 , 16.086, 16.095, 5.986],
[ 15.983, 15.94 , 5.908, 5.269, 3.5 ],
[ 5.554, 4.541, 2.398, -0.876, -5.281]],
[[ 67.909, 67.988, 68.057, 68.113, 68.155],
[ 67.81 , 67.865, 67.905, 67.93 , 67.939],
[ 67.915, 67.934, 67.936, 67.919, 67.875],
[ 67.929, 67.9 , 67.851, 67.756, 67.608],
[ 67.768, 67.653, 67.477, 57.116, 56.783]]],
[[[ 16.01 , 16.293, 16.561, 16.813, 17.047],
[ 16.255, 16.505, 16.736, 16.947, 17.133],
[ 16.794, 16.997, 17.174, 17.319, 17.42 ],
[ 17.108, 17.241, 17.324, 17.321, 7.2 ],
[ 17.07 , 17.015, 6.972, 6.323, 4.544]],
[[ 68.647, 68.741, 68.825, 68.899, 68.962],
[ 68.585, 68.658, 68.719, 68.769, 68.806],
[ 68.737, 68.782, 68.814, 68.834, 68.836],
[ 68.818, 68.829, 68.83 , 68.804, 68.761],
[ 68.768, 68.731, 68.682, 68.581, 68.431]]],
[[[ 16.769, 17.069, 17.356, 17.63 , 17.888],
[ 17.057, 17.329, 17.587, 17.829, 18.053],
[ 17.654, 17.89 , 18.108, 18.305, 18.477],
[ 18.057, 18.248, 18.412, 18.545, 18.634],
[ 18.194, 18.316, 18.389, 18.375, 8.243]],
[[ 69.355, 69.461, 69.559, 69.647, 69.725],
[ 69.323, 69.41 , 69.488, 69.555, 69.613],
[ 69.511, 69.575, 69.628, 69.672, 69.704],
[ 69.64 , 69.677, 69.708, 69.719, 69.722],
[ 69.658, 69.66 , 69.661, 69.63 , 69.584]]]]])