我是ML的初学者,我建立了一个SVM模型来对一些输入进行分类。 我用熊猫来读取我的数据集。分类结果被打印为索引,每个索引都与我的数据集中的标签(类)名称相对应。如何将这些索引转换为它们的名称(字符串)?
例如,我有三个类:[Question,General,Info],但是当我尝试对输入进行分类时,结果是以下数字之一:[0,1,2] 我想将这些数字转换为我拥有的类的名称。
这是我的代码的一部分:
data = pandas.read_csv("classes.csv",encoding='utf-16' )
Train_X, Test_X, Train_Y, Test_Y = sklearn.model_selection.train_test_split(data['input'],data['Class'],test_size=0.3,random_state=None)
Test_Y
和Train_Y
是数字(类)的列表,每个数字都指一个类,我怎么知道每个数字代表什么?
答案 0 :(得分:0)
您需要知道的第一件事是:您的模型正在按预期工作。大多数情况下,它将为每个标签输出概率。因此,如果您的模型输出的内容类似于[0.1, 0.1, 0.8]
,则意味着您要分类的样本中有80%属于位置2的标签。如果按照问题中指示的顺序传递所有标签,即,[question, general, info]
,表示此样本属于info
类。注意这里的顺序很重要,您需要确保在代码中输入模型时。
因此,要输出字符串而不是数字,您需要获取模型输出的数字,并检查包含此关系的列表或词典中的标签。以清单为例:
labels_str = ['question', 'general', 'info']
# preds is a np.array containing the probabilities
preds = model(some_sample)
# this function returns the position of the max value in the array
pos_pred = preds.argmax()
print ("The label for this sample is {}".format(labels_str[pos_pred])
您知道吗?