如何在Theano中更改共享变量的值?

时间:2016-12-01 10:28:57

标签: python theano

我定义了以下类:

class test:

    def __init__(self):
        self.X = theano.tensor.dmatrix('x')
        self.W = theano.shared(value=numpy.zeros((5, 2), dtype=theano.config.floatX), name='W', borrow=True)
        self.out = theano.dot(self.X, self.W)

    def eval(self, X):
        _eval = theano.function([self.X], self.out)
        return _eval(X)

之后我尝试更改W矩阵的值并使用新值进行计算。我是通过以下方式完成的:

m = test()
W = np.transpose(np.array([[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 2.0, 3.0, 3.0, 3.0]]))
dn.W = theano.shared(value=W, name='W', borrow=True)
dn.eval(X)

我得到的结果对应于W中设置的__init__的值(所有元素都是零)。

为什么类没有看到我在初始化后显式设置的W的新值?

1 个答案:

答案 0 :(得分:1)

您刚刚为python变量dn.W创建了一个新的共享变量,但theano的内部计算图仍然链接到旧的共享变量。

更改存储在现有共享变量中的值:

W = np.transpose(np.array([[1.0, 2.0, 3.0, 4.0, 5.0], [2.0, 2.0, 3.0, 3.0, 3.0]]))
dn.W.set_value(W))

注意如果要使用函数调用的结果来更新共享变量,更好的方法是使用updates的{​​{1}}参数。如果共享变量存储在GPU中,这将消除不必要的内存传输。