带有Sklearn的MNIST数据集

时间:2019-11-17 01:08:49

标签: python mnist sklearn-pandas

我正在MNIST数据集上训练线性模型,但我只想训练一个数字4。如何选择X_test,X_train,y_test和y_train?

2 个答案:

答案 0 :(得分:0)

您的分类器需要学习区分不同类别的集合。 如果您只关心数字4,则应将训练和测试集划分为:

  • 4类实例
  • 不是第4类实例:所有其他数字的并集

否则,火车/测试拆分仍是典型的,您希望没有重叠。

答案 1 :(得分:0)

如果只需要识别4s,则是二进制分类问题,因此只需创建一个新的目标变量:如果class为4,则为Y = 1;如果class不是4,则为Y = 0。

  • Train_X将保持不变
  • Train_Y将是与Train_X相关的新目标变量
  • Test_X将保持不变
  • Test_Y将是与Test_X相关的新目标变量。 <\ ul>

    数据会有点不平衡,但这不应该成为问题!