如何制作numpy数组代理对象?有混合课吗?

时间:2015-03-21 12:45:56

标签: python numpy proxy-classes

目前我的代码如下:

class GradientDescent:

    def __init__(self, shape, log_space=False):
        """
        * log_space puts the weights in log space automatically.
        """
        super().__init__()
        self.w = np.ones(shape)

        if log_space:
            self.true_weight_proxy = GradientDescent.LogTrueWeightProxy(self)
        else:
            self.true_weight_proxy = self.w

    class LogTrueWeightProxy:

        def __init__(self, gradient_descent):
            self.gradient_descent = gradient_descent

        @property
        def shape(self):
            return self.gradient_descent.w.shape

        def __getitem__(self, indices):
            return np.exp(self.gradient_descent.w[indices])

        def __setitem__(self, indices, new_weight):
            self.gradient_descent.w[indices] = np.log(new_weight)

        def __array__(self, dtype=None):
            retval = np.exp(self.gradient_descent.w)
            if dtype is not None:
                retval = retval.astype(dtype, copy=False)
            return retval

    @property
    def true_weight(self):
        return self.true_weight_proxy

这允许我做以下事情:

weight = x.true_weight
weight = x.true_weight[2, 3]
x.true_weight[2, 3] = new_weight

如果对象的log_space标志已打开,则权重将放入日志空间。

我的问题是我不喜欢一个接一个地暴露numpy方法来访问例如x.true_weight.shape,继承不适合代理对象。是否有一个numpy mixin类提供方法和属性,如dtypeshape等,以便我的代理对象看起来像一个numpy数组用于所有目的。

0 个答案:

没有答案