How to create training sets for K-fold cross-validation without scikit-learn?

Time: 2020-03-09 03:03:06

Tags: python numpy machine-learning cross-validation k-fold

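I have a dataset with 95 rows and 9 columns and want to do 5-fold cross-validation. For training, the first 8 columns (features) are used to predict the 9th column. My test sets are correct, but my x training set has shape (4,19,9) when it should have only 8 columns, and my y training set also has the wrong shape when it should have 19 rows. Am I indexing the subarrays incorrectly?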

kdata = data[0:95,:] # Need total rows to be divisible by 5, so ignore last 2 rows 
np.random.shuffle(kdata) # Shuffle all rows
folds = np.array_split(kdata, k) # each fold is 19 rows x 9 columns

for i in range (k-1):
    xtest = folds[i][:,0:7] # Set ith fold to be test
    ytest = folds[i][:,8]
    new_folds = np.delete(folds,i,0)
    xtrain = new_folds[:][:][0:7] # training set is all folds, all rows x 8 cols
    ytrain = new_folds[:][:][8]   # training y is all folds, all rows x 1 col

1 Answer:

Answer 0 (score: 1)

Welcome to Stack Overflow.

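Once you have created the new folds, you need to stack them row-wise with np.row_stack(). np.delete(folds, i, 0) returns a 3-D array of shape (4, 19, 9), one slice per remaining fold; stacking turns it into a single 76 x 9 array. Note that np.row_stack() is an alias of np.vstack() and is deprecated as of NumPy 2.0, so np.vstack() is the safer choice today.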
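A minimal sketch of what the stacking step does, assuming five equal folds of 19 rows x 9 columns filled with random stand-in values (not the question's real data):

```python
import numpy as np

# Stand-in data: 95 rows x 9 columns split into 5 folds of 19 rows each (assumed shapes)
folds = np.array_split(np.random.rand(95, 9), 5)

remaining = np.delete(folds, 0, 0)   # drop fold 0 -> one 3-D array, one slice per fold
print(remaining.shape)               # (4, 19, 9)

stacked = np.vstack(remaining)       # concatenate the 4 remaining folds row-wise
print(stacked.shape)                 # (76, 9)
```

Without the stacking step, any column slice you take still operates on the 3-D array, which is where the (4, 19, 9) shape in the question comes from.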

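Also, I think you are slicing the arrays incorrectly. In Python and NumPy, a slice is [inclusive:exclusive], so when you specify [0:7] you only get 7 columns, not the 8 expected feature columns; use [:8] (or [0:8]) instead. In addition, new_folds[:][:][0:7] does not select columns at all: each [:] returns the whole array again, so [0:7] is applied to the first axis (the folds). To slice rows and columns together, use a single comma-separated index such as [:, :8].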
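A small sketch of the slicing difference on a stand-in 19 x 9 fold (values are just 0..170 for illustration):

```python
import numpy as np

fold = np.arange(19 * 9).reshape(19, 9)  # stand-in fold: 19 rows x 9 columns

print(fold[:, 0:7].shape)   # (19, 7): stop index 7 is excluded, so only 7 columns
print(fold[:, :8].shape)    # (19, 8): columns 0..7, the 8 features
print(fold[:, 8].shape)     # (19,): column 8, the target

stack = np.stack([fold] * 4)         # 4 stand-in folds -> shape (4, 19, 9)
print(stack[:][:][0:7].shape)        # (4, 19, 9): [:] returns the whole array, so [0:7] slices folds
```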

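Similarly, if you want 5 folds in the for loop, it should be range(k), which gives [0, 1, 2, 3, 4], not range(k-1), which only gives [0, 1, 2, 3] and never uses the last fold as the test set.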
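A quick check of what each loop bound produces, assuming k = 5 as in the question:

```python
k = 5
print(list(range(k)))      # [0, 1, 2, 3, 4]: every fold is the test set exactly once
print(list(range(k - 1)))  # [0, 1, 2, 3]: fold 4 is never tested
```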

Modified code:

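import numpy as np

k = 5                          # number of folds
data = np.random.rand(95, 9)   # stand-in for your dataset (95 rows x 9 columns); replace with real data

kdata = data[0:95, :]          # keep 95 rows so they split evenly into 5 folds
np.random.shuffle(kdata)       # shuffle all rows in place
folds = np.array_split(kdata, k)  # each fold is 19 rows x 9 columns

for i in range(k):
    xtest = folds[i][:, :8]    # ith fold is the test set: first 8 columns are features
    ytest = folds[i][:, 8]     # 9th column (index 8) is the target
    # drop fold i, then stack the remaining 4 folds row-wise into one 76 x 9 array
    # (np.vstack; np.row_stack is a deprecated alias of it)
    new_folds = np.vstack(np.delete(folds, i, 0))
    xtrain = new_folds[:, :8]
    ytrain = new_folds[:, 8]

    # some print functions to help you debug
    print(f'Fold {i}')
    print(f'xtest shape  : {xtest.shape}')
    print(f'ytest shape  : {ytest.shape}')
    print(f'xtrain shape : {xtrain.shape}')
    print(f'ytrain shape : {ytrain.shape}\n')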

This prints the fold number and the shapes of the training and test sets you want:

Fold 0
xtest shape  : (19, 8)
ytest shape  : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)

Fold 1
xtest shape  : (19, 8)
ytest shape  : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)

Fold 2
xtest shape  : (19, 8)
ytest shape  : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)

Fold 3
xtest shape  : (19, 8)
ytest shape  : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)

Fold 4
xtest shape  : (19, 8)
ytest shape  : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)
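If the number of rows is not divisible by k, the np.delete(folds, i, 0) step breaks: it has to convert the list of folds into one rectangular array, which fails (or, in older NumPy versions, yields an object array) when the folds differ in length. That is why the question drops the last rows. A sketch of a more general variant follows; the function name kfold_splits is hypothetical, not a NumPy API, and it swaps in NumPy's Generator (np.random.default_rng) to shuffle row indices instead of the rows themselves. It assumes the last column is the target and the rest are features.

```python
import numpy as np


def kfold_splits(data, k, seed=None):
    """Hypothetical helper: yield (xtrain, ytrain, xtest, ytest) for each of k folds.

    Assumes the last column of `data` is the target and all others are features.
    """
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(data))       # shuffled row indices; data itself is untouched
    fold_indices = np.array_split(indices, k)  # list of k index arrays, sizes may differ by one

    for i in range(k):
        test_idx = fold_indices[i]
        # concatenate every other fold's indices; works even when fold sizes differ
        train_idx = np.concatenate(fold_indices[:i] + fold_indices[i + 1:])
        train, test = data[train_idx], data[test_idx]
        yield train[:, :-1], train[:, -1], test[:, :-1], test[:, -1]


# Usage with a stand-in 97 x 9 array: all 97 rows are kept, no need to drop any
data = np.random.rand(97, 9)
for i, (xtrain, ytrain, xtest, ytest) in enumerate(kfold_splits(data, 5, seed=0)):
    print(f'Fold {i}: xtrain {xtrain.shape}, xtest {xtest.shape}')
```

Shuffling indices rather than rows also leaves the original array in its original order. By contrast, kdata = data[0:95, :] is a view of data, so np.random.shuffle(kdata) reorders data as well.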