如何在Julia中使用`ModelMatrix`对象进行矩阵乘法?

时间:2017-01-04 21:01:17

标签: matrix dataframe julia

我没有R中DataFrame的经验,我实际上并不理解ModelMatrix

我用它将DataFrame对象转移到一个“矩阵”,其中一个额外的列完全由1个值组成。它对线性回归很有帮助。但是,我发现Julia不支持ModelMatrix的矩阵乘法。

当我尝试时:

# feature is a DataFrames.ModelMatrix{Array{Float64,2}} object
println(feature' * feature)  

我收到以下错误:

ERROR: LoadError: MethodError: no method matching *(::DataFrames.ModelMatrix{Array{Float64,2}}, ::DataFrames.ModelMatrix{Array{Float64,2}})

如果有人试图使用以下内容将ModelMatrix转换为Array

feature_array = convert(Array, feature)

然后出现错误:

ERROR: LoadError: MethodError: Cannot `convert` an object of type DataFrames.ModelMatrix{Array{Float64,2}} to an object of type Array{T,N}

因此,我想知道如何将ModelMatrix转换为另一个Julia可以进行矩阵乘法(*)的对象,如Array

1 个答案:

答案 0 :(得分:4)

如果您检查source of ModelMatrix,则可以看到该对象具有属性m,该属性是基础矩阵的值。您可以使用mm.mmmModelMatrix)将其拉出来。

实施例

生成ModelMatrix

julia> using DataFrames

julia> df = DataFrame(X = randn(4), Y = randn(4), Z = randn(4))
4×3 DataFrames.DataFrame
│ Row │ X        │ Y          │ Z        │
├─────┼──────────┼────────────┼──────────┤
│ 1   │ 0.766271 │ 0.669007   │ 0.232803 │
│ 2   │ 2.08208  │ 0.239115   │ 0.855068 │
│ 3   │ -1.48009 │ 0.00220079 │ 0.105638 │
│ 4   │ -1.57438 │ 0.650456   │ 0.557467 │

julia> mf = ModelFrame(Z ~ X + Y, df)
DataFrames.ModelFrame(4×3 DataFrames.DataFrame
│ Row │ Z        │ X        │ Y          │
├─────┼──────────┼──────────┼────────────┤
│ 1   │ 0.232803 │ 0.766271 │ 0.669007   │
│ 2   │ 0.855068 │ 2.08208  │ 0.239115   │
│ 3   │ 0.105638 │ -1.48009 │ 0.00220079 │
│ 4   │ 0.557467 │ -1.57438 │ 0.650456   │
...

julia> mm = ModelMatrix(mf)
DataFrames.ModelMatrix{Array{Float64,2}}(4x3 Array{Float64,2}:
 1.0   0.766271  0.669007  
 1.0   2.08208   0.239115  
 1.0  -1.48009   0.00220079
 1.0  -1.57438   0.650456  ,[0,1,2])

使用ModelMatrix

julia> m = mm.m
4x3 Array{Float64,2}:
 1.0   0.766271  0.669007  
 1.0   2.08208   0.239115  
 1.0  -1.48009   0.00220079
 1.0  -1.57438   0.650456  

julia> m * rand(3,1)
4x1 Array{Float64,2}:
  1.9474  
  3.08515 
 -0.522879
 -0.371708