更改数组的标签值的细微错误

时间:2019-03-14 14:52:18

标签: dictionary

问题

我正在尝试将原始标签769、770、771和772映射到0、1、2和3。但是,当我尝试使用字典进行此操作时,下面的y_test不会改变。

请注意,dataDict是经过预处理的字典,其中键“ y_test”和“ y_train_valid”都对应于int32类型的一维数组。

mappingDict = {769: 0, 770: 1, 771: 2, 772: 3}
y_train = dataDict["y_train_valid"].copy()
y_test = dataDict["y_test"].copy()
for label, newLabel in mappingDict.items():
    y_train[y_train == label] = newLabel
    y_test[y_test == label] == newLabel

MWE

要重现发生的情况,您可以尝试以下操作,最终得到 enter image description here

y_train = np.array([771, 772, 769, 769, 769, 769, 771, 770, 772, 772], dtype="int32")
y_test = np.array([770, 769, 771, 772, 772, 771, 771, 772, 772, 769], dtype="int32")
mappingDict = {769: 0, 770: 1, 771: 2, 772: 3}

for label, newLabel in mappingDict.items():
    y_train[y_train == label] = newLabel
    y_test[y_test == label] == newLabel

1 个答案:

答案 0 :(得分:0)

这是一个愚蠢的错误,但花了我30分钟。

我以某种方式在=中添加了另外一个y_test[y_test == label] = newLabel,结果是y_test[y_test == label] == newLabel