fmin_cg:因类型不匹配而得到ValueError

时间:2016-06-19 19:11:07

标签: scipy

执行以下代码会产生ValueError:

def f(theta):
    theta = theta.reshape(2, 1)
    return np.linalg.norm(x*theta -y)**2

def fprime(theta):
    theta = theta.reshape(2,1)
    return (x.T*(x*theta - y))

x = np.matrix('1, 2; 3, 4; 5, 6')
y = np.matrix('4; 2; 1')
thetainit = np.matrix('0; 0')

scipy.optimize.fmin_cg(f, np.ravel(thetainit), fprime=fprime) 

但是尺寸没问题,fmin_cg的x0参数使用np.ravel展平,如文档中所述。这里出现错误信息:

   1169     gnorm = vecnorm(gfk, ord=norm)
   1170     while (gnorm > gtol) and (k < maxiter):
-> 1171         deltak = numpy.dot(gfk, gfk)
   1172 
   1173         try:

我感谢任何帮助。

1 个答案:

答案 0 :(得分:0)

此代码有效:

def f(theta):
    theta = theta.reshape(2, 1)
    return np.linalg.norm(x*theta -y)**2
#    return np.linalg.norm(x*theta -y)**2

def fprime(theta):
    theta = theta.reshape(2,1)
    g = np.array(x.T*(x*theta - y)).flatten()
    return g

x = np.matrix('1, 2; 3, 4; 5, 6')
y = np.matrix('4; 2; 1')
thetainit = np.matrix('0; 0')

scipy.optimize.fmin_cg(f, np.ravel(thetainit), fprime=fprime) 

关键是要添加'flatten()&#39;!