I want to split a CSV dataset of one million records into 80% for training and 20% for testing. How can I code this in Java, or with the Weka library?
Answer 0 (score: 4)
You can first randomize your data using the methods provided by Instances:
Random rand = new Random(seed); // create seeded number generator
Instances randData = new Instances(data); // create copy of original data
randData.randomize(rand); // randomize data with number generator
If your data has a nominal class and you want to perform stratified cross-validation:
randData.stratify(folds);
Normally, you would then perform cross-validation like this:
for (int n = 0; n < folds; n++) {
Instances train = randData.trainCV(folds, n);
Instances test = randData.testCV(folds, n);
// further processing, classification, etc.
}
(The source actually notes that "the above code is used by the weka.filters.supervised.instance.StratifiedRemoveFolds filter".)
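Why does 5 folds give an 80/20 split? trainCV and testCV cut the (already randomized) data into contiguous blocks: testCV(folds, n) returns block n and trainCV(folds, n) returns everything else. The plain-Java sketch below reproduces that boundary arithmetic on a List so the sizes can be checked without Weka. The class FoldBounds and its methods are hypothetical, and the rule that the first numInstances % numFolds folds each get one extra instance is my reading of Weka's Instances source, not a documented guarantee.

```java
import java.util.ArrayList;
import java.util.List;

/** Sketch of k-fold boundaries over contiguous blocks of a list (assumed to mirror Weka's testCV/trainCV layout). */
final class FoldBounds {
    /** Returns {first, count}: start index and size of fold numFold in a list of n items. */
    static int[] testRange(int n, int numFolds, int numFold) {
        int count = n / numFolds;
        int offset;
        if (numFold < n % numFolds) {
            count++;            // the first (n % numFolds) folds absorb the remainder
            offset = numFold;
        } else {
            offset = n % numFolds;
        }
        int first = numFold * (n / numFolds) + offset;
        return new int[] {first, count};
    }

    /** The test block: fold numFold only. */
    static <T> List<T> testCV(List<T> data, int numFolds, int numFold) {
        int[] r = testRange(data.size(), numFolds, numFold);
        return new ArrayList<>(data.subList(r[0], r[0] + r[1]));
    }

    /** The training set: everything before and after the test block. */
    static <T> List<T> trainCV(List<T> data, int numFolds, int numFold) {
        int[] r = testRange(data.size(), numFolds, numFold);
        List<T> train = new ArrayList<>(data.subList(0, r[0]));   // items before the test block
        train.addAll(data.subList(r[0] + r[1], data.size()));     // items after the test block
        return train;
    }
}
```

With 1,000 items, testCV(data, 5, 0) returns the items at indices 0 to 199 and trainCV(data, 5, 0) returns the other 800, which is the 20/80 split.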
However, if you only want a single 80/20 split, run it just once with folds set to 5 (each fold holds 1/5 = 20% of the data):
Instances train = randData.trainCV(folds, 0);
Instances test = randData.testCV(folds, 0);
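If you would rather not load the data into Weka at all, a plain-Java holdout split is short. This is a minimal sketch, assuming one record per line (no quoted fields containing line breaks), a single header line, and that the whole file fits in memory; the class CsvSplit, its methods and the file names are hypothetical. It is a simple random split, not a stratified one.

```java
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Plain-Java random holdout split of CSV lines (no Weka, no stratification). */
final class CsvSplit {
    /**
     * Shuffles the data rows with a seeded Random and returns {train, test}.
     * If hasHeader is true, the first line is copied to the top of both outputs.
     */
    static List<List<String>> split(List<String> lines, boolean hasHeader, double trainFraction, long seed) {
        List<String> rows = new ArrayList<>(hasHeader ? lines.subList(1, lines.size()) : lines);
        Collections.shuffle(rows, new Random(seed));          // reproducible for a fixed seed
        int nTrain = (int) Math.round(rows.size() * trainFraction);
        List<String> train = new ArrayList<>();
        List<String> test = new ArrayList<>();
        if (hasHeader) {
            train.add(lines.get(0));
            test.add(lines.get(0));
        }
        train.addAll(rows.subList(0, nTrain));                // first nTrain shuffled rows
        test.addAll(rows.subList(nTrain, rows.size()));       // the remaining rows
        List<List<String>> result = new ArrayList<>();
        result.add(train);
        result.add(test);
        return result;
    }

    /** Reads inFile (UTF-8, header on line 1), splits it and writes the two parts. */
    static void splitFile(Path inFile, Path trainFile, Path testFile, double trainFraction, long seed)
            throws IOException {
        List<String> lines = Files.readAllLines(inFile, StandardCharsets.UTF_8);
        List<List<String>> parts = split(lines, true, trainFraction, seed);
        Files.write(trainFile, parts.get(0), StandardCharsets.UTF_8);
        Files.write(testFile, parts.get(1), StandardCharsets.UTF_8);
    }
}
```

For example, CsvSplit.splitFile(Paths.get("data.csv"), Paths.get("train.csv"), Paths.get("test.csv"), 0.8, 1L) writes both files with the header repeated at the top of each. The fixed seed plays the same role as the seed passed to randomize() above: rerunning gives the same split.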
Answer 1 (score: 2)
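You can do this in Java with the Weka library by using a filter called StratifiedRemoveFolds:
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;
import weka.filters.Filter;
import weka.filters.supervised.instance.StratifiedRemoveFolds;

// Load data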
DataSource source = new DataSource("/some/where/data.csv");
Instances data = source.getDataSet();
// Set class to last attribute
if (data.classIndex() == -1)
data.setClassIndex(data.numAttributes() - 1);
// use StratifiedRemoveFolds to randomly split the data
StratifiedRemoveFolds filter = new StratifiedRemoveFolds();
// set options for creating the subset of data
String[] options = new String[6];
options[0] = "-N"; // indicate we want to set the number of folds
options[1] = Integer.toString(5); // split the data into five random folds
options[2] = "-F"; // indicate we want to select a specific fold
options[3] = Integer.toString(1); // select the first fold
options[4] = "-S"; // indicate we want to set the random seed
options[5] = Integer.toString(1); // set the random seed to 1
filter.setOptions(options); // set the filter options
filter.setInputFormat(data); // prepare the filter for the data format
filter.setInvertSelection(false); // do not invert the selection
// apply filter for test data here
Instances test = Filter.useFilter(data, filter);
// prepare and apply filter for training data here
filter.setInvertSelection(true); // invert the selection to get other data
Instances train = Filter.useFilter(data, filter);
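StratifiedRemoveFolds keeps the class distribution roughly the same in both parts. Outside Weka, one common way to get that effect for a single split is a per-class holdout: shuffle each class separately and move 20% of each into the test set. This is a different technique from Weka's fold-based stratification, shown only as a hedged plain-Java sketch; StratifiedSplit, its Split holder and the parallel labels list are hypothetical names.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/** Per-class holdout: shuffles each class separately and moves a fixed share of it to the test set. */
final class StratifiedSplit {
    /** Simple container for the two halves of the split. */
    static final class Split<T> {
        final List<T> train = new ArrayList<>();
        final List<T> test = new ArrayList<>();
    }

    /**
     * @param rows         the records to split
     * @param labels       labels.get(i) is the class of rows.get(i)
     * @param testFraction share of each class to put in the test set, e.g. 0.2
     * @param seed         seed so the split is reproducible
     */
    static <T> Split<T> split(List<T> rows, List<String> labels, double testFraction, long seed) {
        if (rows.size() != labels.size()) {
            throw new IllegalArgumentException("rows and labels must have the same size");
        }
        // Group row indices by class label, keeping first-seen label order.
        Map<String, List<Integer>> byClass = new LinkedHashMap<>();
        for (int i = 0; i < rows.size(); i++) {
            byClass.computeIfAbsent(labels.get(i), k -> new ArrayList<>()).add(i);
        }
        Random rand = new Random(seed);
        Split<T> result = new Split<>();
        for (List<Integer> indices : byClass.values()) {
            Collections.shuffle(indices, rand);                 // randomize within this class
            int nTest = (int) Math.round(indices.size() * testFraction);
            for (int j = 0; j < indices.size(); j++) {
                T row = rows.get(indices.get(j));
                if (j < nTest) {
                    result.test.add(row);                       // first nTest of the class go to test
                } else {
                    result.train.add(row);
                }
            }
        }
        return result;
    }
}
```

With 80 rows of class a, 20 rows of class b and testFraction 0.2, the test set gets 16 a rows and 4 b rows, so both sets keep the 80:20 class ratio.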