如何解决Theano渐变中断开的输入错误?

时间:2017-04-10 06:08:26

标签: python theano

我试图通过最小化成本来学习使用渐变的权重矩阵。在这样做的时候,我输入错误了。

import theano
import numpy as np

data = np.random.rand(3,3)
weight = np.random.rand(3,3)
target = data

x = theano.tensor.fmatrix('x')
y = theano.tensor.fmatrix('y')
w = theano.shared(weight,'w')


def getY(x,w):
    print "computing y"
    y = x*w
    return y

cost = (x-y).sum()
gradients = theano.tensor.grad(cost, [w])
W_updated = w - (0.1 * gradients[0])

f = theano.function([x], [x,y,w,cost], [(y,getY(x,w)),(w, 
W_updated)],allow_input_downcast=True)

for i in xrange(2):
    output = f(data)

print w.get_value()

0 个答案:

没有答案