如何在朱莉娅进行交叉验证(K-fold)?

时间:2017-11-10 14:02:52

标签: julia

假设我有一个包含两列的数据集。我在我的数据集上建立了线性回归模型。现在我的问题是如何检查模型的准确性。

我发现我的问题的答案是在我的数据集上应用K-fold。我知道K-fold是如何工作的,但我不知道如何在我的Julia程序中实现K-fold。

#suppose I have two columns x and y in my dataset

 x= [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20]
 y=[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]

# now how do I use K-fold to split dataset and also evaluate my algorithm?

2 个答案:

答案 0 :(得分:4)

如评论中所述,一旦给出任何基本来源,就更容易设置一些代码。例如,在这种情况下,K折叠交叉验证可能需要经历如下准备:

julia> x= [1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20];

julia> y=[2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21];

julia> K = 5 # number of folds in validation
5

julia> N = length(x) # number of samples in dataset
20

julia> stops = round.(Int,linspace(1,N,K+1))
6-element Array{Int64,1}:
  1
  5
  9
 12
 16
 20

julia> vsets = [s:e-(e<N)*1 for (s,e) in zip(stops[1:end-1],stops[2:end])]
5-element Array{UnitRange{Int64},1}:
 1:4  
 5:8  
 9:11 
 12:15
 16:20

julia> tsets1 = [1:s-1 for (s,e) in zip(stops[1:end-1],stops[2:end])]
5-element Array{UnitRange{Int64},1}:
 1:0 
 1:4 
 1:8 
 1:11
 1:15

julia> tsets2 = [e+(e<=N)*1:N for (s,e) in zip(stops[1:end-1],stops[2:end])]
5-element Array{UnitRange{Int64},1}:
 6:20 
 10:20
 13:20
 17:20
 21:20

julia> σ = randperm(N);

julia> [x[σ[vsets[i]]] for i=1:K]   # validation sets
5-element Array{Array{Int64,1},1}:
 [5, 13, 6, 10]   
 [16, 4, 2, 3]    
 [9, 19, 20]      
 [17, 12, 14, 11] 
 [8, 1, 18, 7, 15]

julia> [x[vcat(σ[tsets1[i]],σ[tsets2[i]])] for i=1:K]   # training sets
5-element Array{Array{Int64,1},1}:
 [4, 2, 3, 9, 19, 20, 17, 12, 14, 11, 8, 1, 18, 7, 15]   
 [5, 13, 6, 10, 19, 20, 17, 12, 14, 11, 8, 1, 18, 7, 15] 
 [5, 13, 6, 10, 16, 4, 2, 3, 12, 14, 11, 8, 1, 18, 7, 15]
 [5, 13, 6, 10, 16, 4, 2, 3, 9, 19, 20, 1, 18, 7, 15]    
 [5, 13, 6, 10, 16, 4, 2, 3, 9, 19, 20, 17, 12, 14, 11]  

这可能是令人满意的。有关K-fold交叉验证的更多详细信息,请访问维基百科:https://en.wikipedia.org/wiki/Cross-validation_(statistics)#k-fold_cross-validation

答案 1 :(得分:2)

您可以使用MLDataUtils.jl中的folds

kfolds([1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20],5)