theano函数的错误输入参数

时间:2015-05-27 04:50:17

标签: theano

我是theano的新手。我正在尝试实现简单的线性回归,但我的程序会抛出以下错误:

  

TypeError :(在名称为“/home/akhan/Theano-Project/uog/theano_application/linear_regression.py:36”的theano函数的错误输入参数,索引0(从0开始)','预期数组 - 喜欢对象,但找到了一个变量:也许你试图在一个(可能是共享的)变量上调用一个函数而不是数字数组?')

这是我的代码:

import theano
from theano import tensor as T
import numpy as np
import matplotlib.pyplot as plt

x_points=np.zeros((9,3),float)
x_points[:,0] = 1
x_points[:,1] = np.arange(1,10,1)
x_points[:,2] = np.arange(1,10,1) 
y_points = np.arange(3,30,3) + 1


X = T.vector('X')
Y = T.scalar('Y')

W = theano.shared(
            value=np.zeros(
                (3,1),
                dtype=theano.config.floatX
            ),
            name='W',
            borrow=True
        )

out = T.dot(X, W)
predict = theano.function(inputs=[X], outputs=out)

y = predict(X)  # y = T.dot(X, W) work fine

cost = T.mean(T.sqr(y-Y))

gradient=T.grad(cost=cost,wrt=W)

updates = [[W,W-gradient*0.01]]

train = theano.function(inputs=[X,Y], outputs=cost, updates=updates, allow_input_downcast=True)


for i in np.arange(x_points.shape[0]):
    print "iteration" + str(i)
    train(x_points[i,:],y_points[i])

sample = np.arange(x_points.shape[0])+1
y_p = np.dot(x_points,W.get_value())
plt.plot(sample,y_p,'r-',sample,y_points,'ro')
plt.show()

此错误背后的解释是什么? (没有从错误消息中获得)。在此先感谢。

1 个答案:

答案 0 :(得分:5)

Theano在定义计算图和使用这样的图计算结果的函数之间存在重要区别。

定义时

out = T.dot(X, W)
predict = theano.function(inputs=[X], outputs=out)

您首先根据outX设置W的计算图表。请注意,X是纯粹的符号变量,它没有任何值,但out的定义告诉Theano,"给定X的值,是如何计算out"。

另一方面,predicttheano.function,它采用out的计算图和X的实际数值来生成数字输出。调用它时传入theano.function的内容始终必须具有实际数值。所以做

是没有意义的
y = predict(X)

因为X是一个符号变量而且没有实际值。

您希望这样做的原因是您可以使用y来进一步构建计算图。但是没有必要使用predictpredict的计算图已经在前面定义的变量out中可用。因此,您只需删除定义y的行,然后将费用定义为

cost = T.mean(T.sqr(out - Y))

其余代码将不加修改地工作。