如何更改设计矩阵的数据维度?

时间:2015-02-25 10:54:27

标签: python matrix

我正在为我的数据创建设计矩阵。但是当我尝试打印子集2的矩阵的错误时出现

'*ValueError: all the input arrays must have same number of dimensions*.' 

该函数适用于第一个和第三个子集,但我不知道如何更改函数以考虑所有三个子集。我哪里错了?

def design_matrix(matrix): 
    one_matrix = np.ones(matrix.shape[0]).reshape(matrix.shape[0], 1)
    return np.concatenate((one_matrix, matrix), axis=1)

#Load the dataset
train_data = np.genfromtxt('data.dt')
subset_1 = train_data[:, 3:4]
subset_2 = train_data[:, 5]
subset_3 = train_data[:, 1:5]

design_1 = design_matrix(subset_1)
design_2 = design_matrix(subset_2)
design_3 = design_matrix(subset_3)

print design_2

子集2的数据如下所示。我不知道如何更改尺寸,以便前3行没有额外的列。

[  47.    63.    60.    39.    28.    26.    22.    11.    21.    40.    78.
   122.   103.    73.    47.    35.    11.     5.    16.    34.    70.   81.
   111.   101.    73.    40.    20.    16.     5.    11.    22.    40.   60.
   80.9   83.4   47.7   47.8   30.7   12.2    9.6   10.2   32.4   47.6
   54.    62.9   85.9   61.2   45.1   36.4   20.9   11.4   37.8   69.8
   106.1  100.8   81.6   66.5   34.8   30.6    7.    19.8   92.5  154.4
   125.9   84.8   68.1   38.5   22.8   10.2   24.1   82.9  132.   130.9
   118.1   89.9   66.6   60.    46.9   41.    21.3   16.     6.4    4.1
   6.8   14.5   34.    45.    43.1   47.5   42.2   28.1   10.1    8.1
   2.5    0.     1.4    5.    12.2   13.9   35.4   45.8   41.    30.1
   23.9   15.6    6.6    4.     1.8    8.5   16.6   36.3   49.6   64.2
   67.    70.9   47.8   27.5    8.5   13.2   56.9  121.5  138.3  103.2
   85.7   64.6   36.7   24.2   10.7   15.    40.1   61.5   98.5  124.7
   96.3   66.6   64.5   54.1   39.    20.6    6.7    4.3   22.7   54.8
   93.8   95.8   77.2   59.1   44.    47.    30.5   16.3    7.3   37.6
   74.   139.   111.2  101.6   66.2   44.7   17.    11.3   12.4    3.4
   6.    32.3   54.3   59.7   63.7   63.5   52.2   25.4   13.1    6.8
   6.3    7.1   35.6   73.    85.1   78.    64.    41.8   26.2   26.7
   12.1    9.5    2.7    5.    24.4   42.    63.5   53.8   62.    48.5
   43.9   18.6    5.7    3.6    1.4    9.6   47.4]

1 个答案:

答案 0 :(得分:0)

one_matrixmatrix具有不同的形状(202,1)202。这就是错误信息所说的内容。