Excessive processing time in R

Time: 2017-03-24 15:28:33

Tags: r machine-learning

I have the code below, and for some reason the knn function takes a very long time to run. I'm not sure how to improve it.

library(tidyverse)
library(stringr)
library(class) # KNN
library(e1071) # SVM

train <- read_csv('https://raw.githubusercontent.com/idc9/stor390/master/data/human_activity_train.csv')
test <- read_csv('https://raw.githubusercontent.com/idc9/stor390/master/data/human_activity_test.csv')

train <- train %>% 
    filter(activity == 2 | activity == 3)

test <- test %>% 
    filter(activity == 2 | activity == 3)

get_knn_error_rates <- function(train_data, test_data, k){
    # computes KNN test/train error rates

    # break the train/test data into x matrix and y vectors
    # use the function arguments, not the global train/test objects
    test_data_x <- test_data %>% select(-activity)
    test_data_y <- test_data$activity

    train_data_x <- train_data %>% select(-activity)
    train_data_y <- train_data$activity

    # get predictions on training data
    knn_train_prediction <- knn(train=train_data_x, # training x
                            test=train_data_x, # test x
                            cl=train_data_y, # train y
                            k=k) # set k

    # get predictions on test data
    knn_test_prediction <- knn(train=train_data_x, # training x
                           test=test_data_x, # test x
                           cl=train_data_y, # train y
                            k=k) # set k

    # training error rate
    tr_err <- mean(train_data_y != knn_train_prediction)
    # test error rate
    tst_err <- mean(test_data_y != knn_test_prediction)

    list(tr=tr_err, tst=tst_err)
}

k_values <- seq(from=1, to= 41, by=2)
num_k <- length(k_values)

error_df <- tibble(k=rep(0, num_k),
                    tr=rep(0, num_k),
                    tst=rep(0, num_k))

for(i in 1:num_k){
    # fix k for this loop iteration
    k <- k_values[i]

    # get_knn_error_rates() is defined above;
    # it computes the train/test errors for knn
    errs <- get_knn_error_rates(train, test, k)

    # store values in the data frame
    error_df[i, 'k'] <- k
    error_df[i, 'tr'] <- errs[['tr']]
    error_df[i, 'tst'] <- errs[['tst']]
}

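I timed the last for loop with system.time, and it reported 427.75 seconds.

When I tried 5-fold cross-validation, the computation took about 30 minutes.

Is there a way to improve this?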
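
One possible direction, sketched under the assumption that the repeated distance computation is the bottleneck: each call to knn() recomputes every train/test distance, even though the distances do not depend on k. The sketch below computes the squared Euclidean distances once and reuses the neighbour ordering for every k. The function knn_all_k() and the toy data are hypothetical names made up for illustration, not part of the class package, and the tie handling (first label in sorted order) may differ from class::knn, which breaks ties at random.

```r
# Hypothetical helper: predictions for several k from a single distance computation.
knn_all_k <- function(train_x, train_y, query_x, k_values) {
    train_x <- as.matrix(train_x)
    query_x <- as.matrix(query_x)
    train_y <- as.character(train_y)

    # ||q - t||^2 = ||q||^2 + ||t||^2 - 2 q.t, for all query/train pairs at once
    d2 <- outer(rowSums(query_x^2), rowSums(train_x^2), "+") -
        2 * tcrossprod(query_x, train_x)

    # labels of the max_k nearest training rows, one column per query row
    max_k <- max(k_values)
    nn_labels <- apply(d2, 1, function(d) train_y[order(d)[seq_len(max_k)]])
    nn_labels <- matrix(nn_labels, nrow = max_k)

    # majority vote among the first k neighbours, for each k
    preds <- sapply(k_values, function(k) {
        apply(nn_labels[seq_len(k), , drop = FALSE], 2, function(v) {
            names(which.max(table(v)))
        })
    })

    # rows = query observations, columns = values of k
    matrix(preds, nrow = nrow(query_x),
           dimnames = list(NULL, as.character(k_values)))
}

# Toy data (made up for illustration): two well-separated classes labelled 2 and 3
set.seed(1)
toy_x <- rbind(matrix(rnorm(40, mean = 0), ncol = 2),
               matrix(rnorm(40, mean = 5), ncol = 2))
toy_y <- rep(c(2, 3), each = 20)

toy_pred <- knn_all_k(toy_x, toy_y, toy_x, k_values = c(1, 3, 5))
toy_err <- colMeans(toy_pred != as.character(toy_y))  # training error rate per k
print(toy_err)
```

With the data from the question, the same call would presumably take train %>% select(-activity), train$activity, and test %>% select(-activity) as its first three arguments and k_values as the last, so the loop over k would no longer call knn() at all. The full distance matrix has to fit in memory, which should be acceptable for a few thousand rows but is worth checking for larger data.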

0 answers:

No answers