从'lm()`预测'mlm'线性模型对象

时间:2016-09-18 03:27:33

标签: r regression linear-regression lm mlm

我有三个数据集:

响应 - 矩阵5(样本)×10(因变量)

预测变量 - 5个矩阵(样本)x 2(自变量)

test_set - 10个矩阵(样本)x 10(响应中定义的因变量)

response <- matrix(sample.int(15, size = 5*10, replace = TRUE), nrow = 5, ncol = 10)
colnames(response) <- c("1_DV","2_DV","3_DV","4_DV","5_DV","6_DV","7_DV","8_DV","9_DV","10_DV") 
predictors <- matrix(sample.int(15, size = 7*2, replace = TRUE), nrow = 5, ncol = 2)
colnames(predictors) <- c("1_IV","2_IV")
test_set <- matrix(sample.int(15, size = 10*2, replace = TRUE), nrow = 10, ncol = 2)
colnames(test_set) <- c("1_IV","2_IV")

我正在使用定义为响应和预测集的组合的训练集来进行多元线性模型,我想使用此模型对测试集进行预测:

training_dataframe <- data.frame(predictors, response)
fit <- lm(response ~ predictors, data = training_dataframe)
predictions <- predict(fit, data.frame(test_set))

然而,预测的结果真的很奇怪:

predictions

首先,矩阵尺寸为5 x 10,这是响应变量中的样本数量,以DV数量表示。

我对R中的这种分析并不是很熟练,但是我不应该得到一个10 x 10矩阵,这样我对test_set中的每一行都有预测?

对此问题的任何帮助将不胜感激, 马丁

1 个答案:

答案 0 :(得分:5)

你正在踩到R中一个支持不足的部分。你拥有的模型类是&#34; mlm&#34;,即&#34;多个线性模型&#34;,这不是标准&#34;流明&#34;类。当你有一组共同的协变量/预测变量的几个(独立的)响应变量时,你会得到它。虽然lm()函数可以适合这种模型,但predict方法对于#34; mlm&#34;类。如果您查看methods(predict),就会看到predict.mlm*。通常用于带有&#34; lm&#34;的线性模型。当您致电predict.lm时,系统会调用predict;但对于一个&#34; mlm&#34;类predict.mlm*被调用。

predict.mlm*太原始了。它不允许se.fit,即它不能产生预测误差,置信度/预测间隔等,尽管这在理论上是可能的。它只能计算预测均值。如果是这样,我们为什么要使用predict.mlm*?预测均值可以通过平凡的矩阵 - 矩阵乘法获得(在标准&#34; lm&#34;类中这是一个矩阵向量乘法),所以我们可以自己做。

考虑这个小的,重现的例子。

set.seed(0)
## 2 response of 10 observations each
response <- matrix(rnorm(20), 10, 2)
## 3 covariates with 10 observations each
predictors <- matrix(rnorm(30), 10, 3)
fit <- lm(response ~ predictors)

class(fit)
# [1] "mlm" "lm"

beta <- coef(fit)
#                  [,1]       [,2]
#(Intercept)  0.5773235 -0.4752326
#predictors1 -0.9942677  0.6759778
#predictors2 -1.3306272  0.8322564
#predictors3 -0.5533336  0.6218942

当您有预测数据集时:

# 2 new observations for 3 covariats
test_set <- matrix(rnorm(6), 2, 3)

我们首先需要填充拦截列

Xp <- cbind(1, test_set)

然后进行矩阵乘法

pred <- Xp %*% beta
#          [,1]      [,2]
#[1,] -2.905469  1.702384
#[2,]  1.871755 -1.236240

也许您已经注意到我甚至没有在这里使用数据框。 是的,因为你拥有矩阵形式的所有内容,所以没有必要。对于那些R向导,使用lm.fit甚至qr.solve可能更直接。

但作为一个完整的答案,必须证明如何使用predict.mlm来获得我们想要的结果。

## still using previous matrices
training_dataframe <- data.frame(response = I(response), predictors = I(predictors))
fit <- lm(response ~ predictors, data = training_dataframe)
newdat <- data.frame(predictors = I(test_set))
pred <- predict(fit, newdat)
#          [,1]      [,2]
#[1,] -2.905469  1.702384
#[2,]  1.871755 -1.236240

使用I()时请注意data.frame()。当我们想要获得矩阵的数据框时,这是必须的。您可以比较两者之间的区别:

str(data.frame(response = I(response), predictors = I(predictors)))
#'data.frame':  10 obs. of  2 variables:
# $ response  : AsIs [1:10, 1:2] 1.262954.... -0.32623.... 1.329799.... 1.272429.... 0.414641.... ...
# $ predictors: AsIs [1:10, 1:3] -0.22426.... 0.377395.... 0.133336.... 0.804189.... -0.05710.... ...

str(data.frame(response = response, predictors = predictors))
#'data.frame':  10 obs. of  5 variables:
# $ response.1  : num  1.263 -0.326 1.33 1.272 0.415 ...
# $ response.2  : num  0.764 -0.799 -1.148 -0.289 -0.299 ...
# $ predictors.1: num  -0.2243 0.3774 0.1333 0.8042 -0.0571 ...
# $ predictors.2: num  -0.236 -0.543 -0.433 -0.649 0.727 ...
# $ predictors.3: num  1.758 0.561 -0.453 -0.832 -1.167 ...

如果没有I()来保护矩阵输入,数据会很混乱。令人惊讶的是,这不会给lm带来问题,但如果您不使用predict.mlmI()将很难获得正确的预测矩阵。

好吧,我建议使用&#34;列表&#34;而不是&#34;数据框&#34;在这种情况下。data中的{{strong} lm参数以及newdata中的predict参数允许列表输入。 A&#34;列表&#34;是一种比数据框架更通用的结构,它可以毫无困难地保存任何数据结构。我们可以这样做:

## still using previous matrices
training_list <- list(response = response, predictors = predictors)
fit <- lm(response ~ predictors, data = training_list)
newdat <- list(predictors = test_set)
pred <- predict(fit, newdat)
#          [,1]      [,2]
#[1,] -2.905469  1.702384
#[2,]  1.871755 -1.236240

也许在最后,我应该强调使用公式接口而不是矩阵接口总是安全的。我将使用R内置数据集trees作为可重现的示例

fit <- lm(cbind(Girth, Height) ~ Volume, data = trees)

## use the first two rows as prediction dataset
predict(fit, newdata = trees[1:2, ])
#     Girth   Height
#1 9.579568 71.39192
#2 9.579568 71.39192

也许你还记得我的说法predict.mlm*过于原始而无法支持se.fit。这是测试它的机会。

predict(fit, newdata = trees[1:2, ], se.fit = TRUE)
#Error in predict.mlm(fit, newdata = trees[1:2, ], se.fit = TRUE) : 
#  the 'se.fit' argument is not yet implemented for "mlm" objects

哎呀......置信度/预测间隔(实际上没有计算标准误差的能力,不可能产生这些间隔)?好吧,predict.mlm*会忽略它。

predict(fit, newdata = trees[1:2, ], interval = "confidence")
#     Girth   Height
#1 9.579568 71.39192
#2 9.579568 71.39192

predict.lm相比,这是如此不同。