如何获得分类中的班级名称?

时间:2019-06-29 19:26:25

标签: machine-learning classification python-3.6 svm

我是ML的初学者,我建立了一个SVM模型来对一些输入进行分类。 我用熊猫来读取我的数据集。分类结果被打印为索引,每个索引都与我的数据集中的标签(类)名称相对应。如何将这些索引转换为它们的名称(字符串)?

例如,我有三个类:[Question,General,Info],但是当我尝试对输入进行分类时,结果是以下数字之一:[0,1,2] 我想将这些数字转换为我拥有的类的名称。

这是我的代码的一部分:

data = pandas.read_csv("classes.csv",encoding='utf-16' )


Train_X, Test_X, Train_Y, Test_Y = sklearn.model_selection.train_test_split(data['input'],data['Class'],test_size=0.3,random_state=None)

Test_YTrain_Y是数字(类)的列表,每个数字都指一个类,我怎么知道每个数字代表什么?

1 个答案:

答案 0 :(得分:0)

您需要知道的第一件事是:您的模型正在按预期工作。大多数情况下,它将为每个标签输出概率。因此,如果您的模型输出的内容类似于[0.1, 0.1, 0.8],则意味着您要分类的样本中有80%属于位置2的标签。如果按照问题中指示的顺序传递所有标签,即,[question, general, info],表示此样本属于info类。注意这里的顺序很重要,您需要确保在代码中输入模型时。

因此,要输出字符串而不是数字,您需要获取模型输出的数字,并检查包含此关系的列表或词典中的标签。以清单为例:

labels_str = ['question', 'general', 'info']

# preds is a np.array containing the probabilities
preds = model(some_sample)

# this function returns the position of the max value in the array
pos_pred = preds.argmax() 

print ("The label for this sample is {}".format(labels_str[pos_pred])

您知道吗?