导入数据集时如何随机复制某些文档?

时间:2018-06-22 07:19:40

标签: python scikit-learn

我正在做一个项目,在这里我必须检测数据集中存在的重复项。为了创建模型,我从sklearn获取了数据集20newsgroup。

from sklearn.datasets import fetch_20newsgroups

categories = [
    'alt.atheism',
    'talk.religion.misc',
    'comp.graphics',
    'sci.space',
]
print("Loading 20 newsgroups dataset for categories:")
data_train = fetch_20newsgroups(subset='train', categories=categories,
                                shuffle=True, random_state=42)

data_test = fetch_20newsgroups(subset='test', categories=categories,
                               shuffle=True, random_state=42)
print('data loaded')

但是该数据集具有唯一条目,因此我必须自己创建重复项。这里data_train是随机选择的用于训练模型的文档组成的数组。

有人知道要对这些文档进行随机复制吗,所以最后我得到了一个包含重复条目的数据集?

1 个答案:

答案 0 :(得分:2)

返回类型public function__construct(){ {$this->middleware('auth', ['except' => ['index']]); } 是一个fetch_20newsgroups对象。它在Bunch变量中包含文档,在data变量中包含对应的标签s。因此,target是一个列表,data_train.data是一个numpy数组。导入数据集后,您可能使用了data_train.targetdata_train.data。下面的代码是从这些容器中复制一行。

data_train.target

import random
def duplicate(X, y):
    index = random.randint(0, len(X) - 1)
    X.append(X[index])
    y = np.append(y, y[index])
    return X, y

X = data_train.data
y = data_train.target

print(len(X))
print(len(y))

X, y = duplicate(X, y)

print(len(X))
print(len(y))

您也可以对>>> 2034 >>> 2034 >>> 2035 >>> 2035 做同样的事情。 data_test函数复制一行并返回文档,标签。您可能需要扩展该功能,以通过一次调用复制多行。

注意:如果要让duplicate对象具有重复的行。复制Bunch的行后,您可能会做类似data_train.data = X的操作,但是我不熟悉此对象类型,因此不确定该对象的行为。

修改

对于多个重复的行,可以多次调用上面的函数。多个重复项的实现效率更高,如下所示:

X