我正在使用Tensorflow框架进行分类预测。我的数据集包含大约1160个输出类。输出类值为6位数。例如,789954。在使用Tensorflow训练和测试数据集后,我得到了大约99%的准确度。
现在第二步是在csv文件中获取预测结果,以便我可以检查预测结果(logits)与集合中的原始标签匹配。我们知道logits是我的一个热门编码向量。所以,我已经完成了以下步骤来解码一个热编码。
prediction=tf.argmax(logits,1)
print(prediction.eval(feed_dict={features : test_features, keep_prob: 1.0}))
prediction = np.asarray(prediction.eval(feed_dict={features : test_features, keep_prob: 1.0}))
prediction = np.reshape(prediction, (test_features.shape[0],1))
np.savetxt("prediction.csv", prediction, delimiter=",")
csv文件中的结果值对于所有条目仅为0.00E + 00。但我的期望是各个csv条目的6位数代码。我想我的单热编码出错了。
任何帮助都很明显。
已添加: 我有一个这样的热编码。
labels = tf.one_hot(labels, n_classes)
并且n_classes = 1160并且所有值将是6位数
答案 0 :(得分:1)
如果每个描述只有one-label
,那么您的方法就可以了。您可以使用sklearn LabelEncoder
将类别转换为标签。您的标签应为每个要素添加[0 to 1160]
之间的值,然后执行on-hot encoding
。