Numpy csv文件groupby

时间:2016-07-01 22:38:59

标签: python csv numpy

我正在阅读包含两列的CSV文件。第二列描述了标签。我想看看我的CSV文件中存在多少个标签。

我的解决方案涉及一个简单的for循环和一个dictionary对象:

dataset = np.genfromtxt(input_file, invalid_raise=False, missing_values='N/A', delimiter=",", dtype=str,
                            skip_header=1)
    np.load

    X = dataset[:, 0]
    y = dataset[:, 1]
    classes = dict()
    for label in y:
        if label in classes:
            classes[label] += 1
        else:
            classes[label] = 1

    print classes

示例:

{'Error Processing Payment': 1, 'General Question': 1, 'Display': 5, 'Software': 2}

我想知道是否有NumPy函数,例如groupby,它会给我相同的功能吗?

1 个答案:

答案 0 :(得分:1)

您可以通过将数据集转换为结构化数组来使用numpy的花式索引:

dataset = np.genfromtxt(input_file, invalid_raise=False, missing_values='N/A', delimiter=",", dtype=[('data', 'S50'), ('label', 'S50')],
                        skip_header=1)

然后您获得'Error Processing Payment'的频率就像:

一样简单
len(dataset[dataset['label'] == 'Error Processing Payment'])

另外,您可以使用以下方式获取所有可用标签:

set(dataset['label'])