如何理解拆分数据的功能

时间:2019-12-10 17:18:31

标签: python scikit-learn training-data k-fold

有人可以帮助我了解此功能的作用吗?

我了解行打印的内容,但是之后我有点迷茫。从train_data开始。

def stratifiedShuffleSplit_data(X, y):
    sss = StratifiedShuffleSplit(n_splits=5, test_size=0.5, random_state=0)
    for train_index, test_index in sss.split(X, y):
        print("len(TRAIN):", len(train_index), "len(TEST):", len(test_index))
        print("TRAIN:", train_index, "TEST:", test_index)

        train_data = [df.loc[ind] for ind in train_index]
        test_data = [df.loc[ind] for ind in test_index]
        save_datarows(train_data, datafile+".train")
        save_datarows(test_data, datafile+".test")

1 个答案:

答案 0 :(得分:0)

假设您使用的是熊猫包装,

 pd.DataFrame.loc 

是一种基于位置的索引器-这是一个过于简化的版本。我将发布一些资源,以帮助您更好地理解它。

train_data = [df.loc[ind] for ind in train_index]

这里,您基本上遍历了ind列表并存储了各自的值train_data 对于test_data的情况类似

我假设save_datarows是一个自定义函数,用于将train_data存储到扩展名为.train的文件中

希望这会有所帮助。

这是一个很好的参考资料,供您进一步澄清:

Selection with .loc in python

https://www.geeksforgeeks.org/python-pandas-dataframe-loc/