我正在阅读包含两列的CSV
文件。第二列描述了标签。我想看看我的CSV
文件中存在多少个标签。
我的解决方案涉及一个简单的for
循环和一个dictionary
对象:
dataset = np.genfromtxt(input_file, invalid_raise=False, missing_values='N/A', delimiter=",", dtype=str,
skip_header=1)
np.load
X = dataset[:, 0]
y = dataset[:, 1]
classes = dict()
for label in y:
if label in classes:
classes[label] += 1
else:
classes[label] = 1
print classes
示例:
{'Error Processing Payment': 1, 'General Question': 1, 'Display': 5, 'Software': 2}
我想知道是否有NumPy
函数,例如groupby
,它会给我相同的功能吗?
答案 0 :(得分:1)
您可以通过将数据集转换为结构化数组来使用numpy的花式索引:
dataset = np.genfromtxt(input_file, invalid_raise=False, missing_values='N/A', delimiter=",", dtype=[('data', 'S50'), ('label', 'S50')],
skip_header=1)
然后您获得'Error Processing Payment'
的频率就像:
len(dataset[dataset['label'] == 'Error Processing Payment'])
另外,您可以使用以下方式获取所有可用标签:
set(dataset['label'])