Matlab到Python设计矩阵函数的翻译

时间:2017-03-11 15:53:01

标签: python matlab numpy

去年我在Matlab中编写了一个线性回归程序设计矩阵的代码。它工作得很好。现在,我需要将其转换为Python并在Pycharm中运行。我已经好几天了,虽然我是Python的新手,但我在翻译中找不到任何错误,但是当代码与程序的其余部分一起运行时,我收到错误。

matlab中的代码:

function DesignMatrix = design_matrix( xTrain, M )
% This function calculates the Design Matrix for
% a M-th degree polynomial
% xTrain - training set Nx1
% M - polynomial degree 0,1,2,...

N = size(xTrain,1);
DesignMatrix = zeros(N,M+1); 
for i=1:M+1
  DesignMatrix(:,i)=xTrain.^(i-1)
end
end

和我在Python中的翻译(np代表numpy,导入):

def design_matrix(x_train,M):
    '''
    :param x_train: input vector Nx1
    :param M: polynomial degree 0,1,2,...
    :return: Design Matrix Nx(M+1) for M degree polynomial
    '''
    desm = np.zeros(shape =(len(x_train), M+1))
    for i in range(1, M+1):
        desm[:,i] = np.power(x_train, (i-1))
    return desm
    pass

错误指向此行:desm[:,i] = np.power(x_train, (i-1)),这是一个值错误。我尝试使用在线翻译ompc但它似乎已经过时,因为它对我不起作用。如果我的翻译有任何明显的错误,有人可以向我解释一下吗?我知道它是更大程序的一部分,但我要问的只是语法翻译本身。如果它是正确的,我会尝试找到任何其他错误,但到目前为止我没有提出任何错误。谢谢。

修改:追溯

ERROR: test_design_matrix (test.TestDesignMatrix)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "...\test.py", line 61, in test_design_matrix
    dm_computed = design_matrix(x_train, M)
  File "...\content.py", line 34, in design_matrix
    desm[:,i] = np.power(x_train, (i-1))
ValueError: could not broadcast input array from shape (20,1) into shape (20)

我无法更改test.py文件,它已提供给我并且无法更改,因此我只依赖于第二个错误。

从提供错误的函数test.py中摘录:

def test_design_matrix(self):
    x_train = TEST_DATA['design_matrix']['x_train']
    M = TEST_DATA['design_matrix']['M']
    dm = TEST_DATA['design_matrix']['dm']
    dm_computed = design_matrix(x_train, M)
    max_diff = np.max(np.abs(dm - dm_computed))
    self.assertAlmostEqual(max_diff, 0, 8)

3 个答案:

答案 0 :(得分:1)

你可以试试这个:

def design_matrix(x_train,M):
    '''
    :param x_train: input vector Nx1
    :param M: polynomial degree 0,1,2,...
    :return: Design Matrix Nx(M+1) for M degree polynomial
    '''
    x_train = np.asarray(x_train)
    desm = np.zeros(shape =(len(x_train), M+1))
    for i in range(0, M+1):
        desm[:,i] = np.power(x_train, i).reshape(x_train.shape[0],)
    return desm

错误来自不兼容的Numpy数组维度。 desm [:,i]具有形状(n,),但是您尝试存储的值具有形状(n,1),因此您需要将其重新整形为(n,)。另外,正如GLR所提到的,Python索引从0开始,因此您需要修改索引,并且函数执行在返回行停止,因此根本没有到达传递线。

答案 1 :(得分:0)

我看到三个错误:

  • 在Python中,索引从零开始。

  • 要为数组中的所有项目供电,可以使用**运算符。

  • pass什么都不做,因为它是在return语句之后。该功能永远不会达到这一点。

我会尝试这个:

def design_matrix(x_train,M):
    '''
    :param x_train: input vector Nx1
    :param M: polynomial degree 0,1,2,...
    :return: Design Matrix Nx(M+1) for M degree polynomial
    '''
    desm = np.zeros(shape =(len(x_train), M+1))
    for i in range(0, M+1):
        desm[:,i] = x_train.squeeze() ** (i-1)
    return desm

答案 2 :(得分:0)

您可能有兴趣知道可以使用 patsy 语言和模块为多项式回归创建正交设计矩阵。

>>> import numpy as np
>>> from patsy import dmatrices, dmatrix, demo_data, Poly
>>> data = demo_data("a", "b", "x1", "x2", "y", "z column")
>>> dmatrix('C(x1, Poly)', data)
DesignMatrix with shape (8, 8)
Columns:
['Intercept', 'C(x1, Poly).Linear', 'C(x1, Poly).Quadratic', 'C(x1, Poly).Cubic', 'C(x1, Poly)^4', 'C(x1, Poly)^5', 'C(x1, Poly)^6', 'C(x1, Poly)^7']
Terms:
'Intercept' (column 0), 'C(x1, Poly)' (columns 1:8)
(to view full data, use np.asarray(this_obj))
>>> dm = dmatrix('C(x1, Poly)', data)
>>> np.asarray(dm)
array([[ 1.        ,  0.23145502, -0.23145502, -0.43082022, -0.12087344,
         0.36376642,  0.55391171,  0.35846409],
       [ 1.        , -0.23145502, -0.23145502,  0.43082022, -0.12087344,
        -0.36376642,  0.55391171, -0.35846409],
       [ 1.        ,  0.07715167, -0.38575837, -0.18463724,  0.36262033,
         0.32097037, -0.30772873, -0.59744015],
       [ 1.        ,  0.54006172,  0.54006172,  0.43082022,  0.28203804,
         0.14978617,  0.06154575,  0.01706972],
       [ 1.        ,  0.38575837,  0.07715167, -0.30772873, -0.52378493,
        -0.49215457, -0.30772873, -0.11948803],
       [ 1.        , -0.54006172,  0.54006172, -0.43082022,  0.28203804,
        -0.14978617,  0.06154575, -0.01706972],
       [ 1.        , -0.07715167, -0.38575837,  0.18463724,  0.36262033,
        -0.32097037, -0.30772873,  0.59744015],
       [ 1.        , -0.38575837,  0.07715167,  0.30772873, -0.52378493,
         0.49215457, -0.30772873,  0.11948803]])