为什么在使用knn()且k = 1的情况下我的模型如此精确?

时间:2018-07-03 15:56:53

标签: r plotly r-caret knn

我目前正在使用基因组表达水平,年龄和吸烟强度水平来预测肺癌患者的生存天数。我的数据很少; 173位患者和20,438个变量,包括基因表达水平(占20,436个)。我使用80:20的比例将数据分为测试和培训。数据中没有缺失值。

我正在使用knn()训练模型。代码如下所示:

prediction <- knn(train = trainData, test = testData, cl = trainAnswers, k=1)

在您注意到k = 1之前,似乎没有什么异常。 “为什么k = 1?”你可能会问。 k = 1的原因是因为当k = 1时,模型是最准确的。这对我来说毫无意义。有很多问题:

  1. 我正在使用knn()来预测连续变量。我应该使用类似于cox的东西。
  2. 该模型太准确了。以下是测试答案和模型预测的一些示例。对于第一位患者,死亡天数为274。该模型预测为268。对于第二位患者,测试:1147,预测:1135。第三,测试:354,预测:370。第四,测试:995,预测995 。 这怎么可能?在整个测试数据中,该模型平均仅运行了9.0625天!中位数差异为7天,模式为6天。这是结果图: Bar Graph

所以我想我的主要问题是knn()做什么,k代表什么,并且当k = 1时模型如何如此精确?这是我的完整代码(我无法附加实际数据):

# install.packages(c('caret', 'skimr', 'RANN', 'randomForest', 'fastAdaboost', 'gbm', 'xgboost', 'caretEnsemble', 'C50', 'earth'))
library(caret)

# Gather the data and store it in variables
LUAD <- read.csv('/Users/username/Documents/ClinicalData.csv')
geneData <- read.csv('/Users/username/Documents/GenomicExpressionLevelData.csv')
geneData <- data.frame(geneData)
row.names(geneData) = geneData$X
geneData <- geneData[2:514]
colNamesGeneData <- gsub(".","-",colnames(geneData),fixed = TRUE)
colnames(geneData) = colNamesGeneData

# Organize the data
# Important columns are 148 (smoking), 123 (OS Month, basically how many days old), and the gene data. And column 2 (barcode).
LUAD = data.frame(LUAD$patient, LUAD$TOBACCO_SMOKING_HISTORY_INDICATOR, LUAD$OS_MONTHS, LUAD$days_to_death)[complete.cases(data.frame(LUAD$patient, LUAD$TOBACCO_SMOKING_HISTORY_INDICATOR, LUAD$OS_MONTHS, LUAD$days_to_death)), ]
rownames(LUAD)=LUAD$LUAD.patient
LUAD <- LUAD[2:4]

# intersect(rownames(LUAD),colnames(geneData))
# ind=which(colnames(geneData)=="TCGA-778-7167-01A-11R-2066-07")
gene_expression=geneData[, rownames(LUAD)]

# Merge the two datasets to use the geneomic expression levels in your model
LUAD <- data.frame(LUAD,t(gene_expression))
LUAD.days_to_death <- LUAD[,3]
LUAD <- LUAD[,c(1:2,4:20438)]
LUAD <- data.frame(LUAD.days_to_death,LUAD)

set.seed(401)

# Number of Rows in the training data (createDataPartition(dataSet, percentForTraining, boolReturnAsList))
trainRowNum <- createDataPartition(LUAD$LUAD.days_to_death, p=0.8, list=FALSE)

# Training/Test Dataset
trainData <- LUAD[trainRowNum, ]
testData <- LUAD[-trainRowNum, ]

x = trainData[, c(2:20438)]
y = trainData$LUAD.days_to_death
v = testData[, c(2:20438)]
w = testData$LUAD.days_to_death

# Imputing missing values into the data
preProcess_missingdata_model <- preProcess(trainData, method='knnImpute')
library(RANN)
if (anyNA(trainData)) {
    trainData <- predict(preProcess_missingdata_model, newdata = trainData)
}
anyNA(trainData)

# Normalizing the data
preProcess_range_model <- preProcess(trainData, method='range')
trainData <- predict(preProcess_range_model, newdata = trainData)
trainData$LUAD.days_to_death <- y
apply(trainData[,1:20438], 2, FUN=function(x){c('min'=min(x), 'max'=max(x))})

preProcess_range_model_Test <- preProcess(testData, method='range')
testData <- predict(preProcess_range_model_Test, newdata = testData)
testData$LUAD.days_to_death <- w
apply(testData[,1:20438], 2, FUN=function(v){c('min'=min(v), 'max'=max(v))})

# To uncomment, select the text and press 'command' + 'shift' + 'c'
# set.seed(401)
# options(warn=-1)
# subsets <- c(1:10)
# ctrl <- rfeControl(functions = rfFuncs,
#                    method = "repeatedcv",
#                    repeats = 5,
#                    verbose = TRUE)
# lmProfile <- rfe(x=trainData[1:20437], y=trainAnswers,
#                  sizes = subsets,
#                  rfeControl = ctrl)
# lmProfile

trainAnswers <- trainData[,1]
testAnswers <- testData[,1]

library(class)
prediction <- knn(train = trainData, test = testData, cl = trainAnswers, k=1)

#install.packages("plotly")
library(plotly)
Test_Question_Number <- c(1:32)
prediction2 <- data.frame(prediction[1:32])
prediction2 <- as.numeric(as.vector(prediction2[c(1:32),]))
data <- data.frame(Test_Question_Number, prediction2, testAnswers)
names(data) <- c("Test Question Number","Prediction","Answer")

p <- plot_ly(data, x = ~Test_Question_Number, y = ~prediction2, type = 'bar', name = 'Prediction') %>%
    add_trace(y = ~testAnswers, name = 'Answer') %>%
    layout(yaxis = list(title = 'Days to Death'), barmode = 'group')
p
merge <- data.frame(prediction2,testAnswers)

difference <- abs((merge[,1])-(merge[,2]))
difference <- sort(difference)
meanDifference <- mean(difference)
medianDifference <- median(difference)
modeDifference <- names(table(difference))[table(difference)==max(table(difference))]
cat("Mean difference:", meanDifference, "\n")
cat("Median difference:", medianDifference, "\n")
cat("Mode difference:", modeDifference,"\n")

最后,出于澄清目的,ClinicalData.csv是年龄,死亡天数和吸烟强度数据。另一个.csv是基因组表达数据。第29行上方的数据并不重要,因此您只需跳到代码中显示“ set.seed(401)”的部分即可。

编辑:一些数据样本:

days_to_death    OS_MONTHS
121              3.98

NACC1   2001.5708   2363.8063   1419.879
NACC2   58.2948     61.8157     43.4386
NADK    706.868     1053.4424   732.1562
NADSYN1 1628.7634   912.1034    638.6471
NAE1    832.8825    793.3014    689.7123
NAF1    140.3264    165.4858    186.355
NAGA    1523.3441   1524.4619   1858.9074
NAGK    983.6809    899.869     1168.2003
NAGLU   621.3457    510.9453    1172.511
NAGPA   346.9762    257.5654    275.5533
NAGS    460.7732    107.2116    321.9763
NAIF1   217.1219    202.5108    132.3054
NAIP    101.2305    87.8942     77.261
NALCN   13.9628     36.7031     48.0809
NAMPT   3245.6584   1257.8849   5465.6387

1 个答案:

答案 0 :(得分:0)

因为K = 1是最复杂的knn模型。它具有最灵活的决策边界。它会产生过度拟合。在保持数据集上表现不佳(但并非总是如此)时,它将在训练数据中表现良好。