目前我的代码如下:
class GradientDescent:
def __init__(self, shape, log_space=False):
"""
* log_space puts the weights in log space automatically.
"""
super().__init__()
self.w = np.ones(shape)
if log_space:
self.true_weight_proxy = GradientDescent.LogTrueWeightProxy(self)
else:
self.true_weight_proxy = self.w
class LogTrueWeightProxy:
def __init__(self, gradient_descent):
self.gradient_descent = gradient_descent
@property
def shape(self):
return self.gradient_descent.w.shape
def __getitem__(self, indices):
return np.exp(self.gradient_descent.w[indices])
def __setitem__(self, indices, new_weight):
self.gradient_descent.w[indices] = np.log(new_weight)
def __array__(self, dtype=None):
retval = np.exp(self.gradient_descent.w)
if dtype is not None:
retval = retval.astype(dtype, copy=False)
return retval
@property
def true_weight(self):
return self.true_weight_proxy
这允许我做以下事情:
weight = x.true_weight
weight = x.true_weight[2, 3]
x.true_weight[2, 3] = new_weight
如果对象的log_space
标志已打开,则权重将放入日志空间。
我的问题是我不喜欢一个接一个地暴露numpy方法来访问例如x.true_weight.shape
,继承不适合代理对象。是否有一个numpy mixin类提供方法和属性,如dtype
,shape
等,以便我的代理对象看起来像一个numpy数组用于所有目的。