test_train_split将字符串类型的标签转换为np.array。有什么方法可以找回原始标签名称?

时间:2018-12-08 08:34:10

标签: python scikit-learn

我有一个带有字符串类型标签名称的图像数据集。当我使用sklearn库的test_train_split拆分数据时,它将标签转换为np.array类型。有没有办法找回原始的字符串类型标签名称?

下面的代码拆分数据进行训练和测试:

imgs, y = load_images()
train_img,ytrain_img,test_img,ytest_img = train_test_split(imgs,y, test_size=0.2, random_state=1)

如果我打印y,它会给我标签名称,但是如果我打印分割的标签值,它会得到一个数组:

for k in y:
    print(k)
    break
for k in ytrain_img:
    print(k)
    break

输出:

 001.Affenpinscher
 [[[ 97 180 165]
  [ 93 174 159]
  [ 91 169 152]
  ...
 [[ 88 171 156]
 [ 88 170 152]
 [ 84 162 145]
 ...
 [130 209 222]
 [142 220 233]
 [152 230 243]]

 [[ 99 181 163]
 [ 98 178 161]
 [ 92 167 151]
 ...
 [130 212 224]
 [137 216 229]
 [143 222 235]]
 ...
 [[ 85 147 158]
 [ 85 147 158]
 [111 173 184]
 ...
 [227 237 244]
 [236 248 250]
 [234 248 247]]

 [[ 94 154 166]
 [ 96 156 168]
 [133 194 204]
 ...
[226 238 244]
[237 249 253]
[237 252 254]]
...
[228 240 246]
[238 252 255]
[241 255 255]]]

是否可以将数组转换回原始标签名称?

1 个答案:

答案 0 :(得分:0)

否,您在推断train_test_split的输出错误。

train_test_split的工作方式如下:

A_train, A_test, B_train, B_test, C_train, C_test ... 
                             = train_test_split(A, B, C ..., test_size=0.2)

您可以分配多个数组。对于每个给定的数组,它将首先提供训练和测试拆分,然后对下一个数组进行相同的操作,然后对第三个数组进行操作,依此类推。

实际上,您的情况是:

train_img, test_img, ytrain_img, ytest_img = train_test_split(imgs, y, 
                                                              test_size=0.2, 
                                                              random_state=1)

但是您随后混淆了输出名称,并错误地使用了它们。