Why isn't LabelEncoder reading the value?

Time: 2016-08-29 15:40:16

Tags: python scikit-learn one-hot-encoding

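I am trying to one-hot encode a dataset using LabelEncoder and OneHotEncoder from sklearn: first I apply LabelEncoder to each column, then OneHotEncoder to the result. Note: I deliberately set row 1 of two columns of the data frame to nan so that LabelEncoder would not miss that value.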

Here is the code:

training_data.dropna(axis=1,how='any',inplace=True)
print training_data.shape
rows = [1]
training_data.loc[rows, endocing_columns] = float("nan")


print training_data.loc[1].mail_category 
print training_data.loc[1].mail_type 
for col in endocing_columns:
    label_encoder=LabelEncoder()
    oneHot_encoder=OneHotEncoder(sparse=False)
    label_encoder.fit(training_data[col])
    temp_col = pd.DataFrame(label_encoder.transform(training_data[col]))

    oneHot_encoder.fit(temp_col)
    temp = oneHot_encoder.transform(temp_col)
    print training_data.shape
    temp=pd.DataFrame(temp)
    training_data[col].value_counts().index])
    # For side-by-side concatenation the index values must match,
    # so set the index to that of the training_data data frame
    temp=temp.set_index(training_data.index.values)
    # Add the new one-hot encoded variables to the training data frame
    training_data=pd.concat([training_data,temp],axis=1)
    training_data.drop(col, axis=1, inplace=True)

    print label_encoder.classes_
    temp_col = pd.DataFrame(label_encoder.transform(test_data[col]))
    temp = oneHot_encoder.transform(temp_col)

This is the output of the code (note that nan appears among the classes printed for the label encoder):

(478192, 46)
nan
nan
(478192, 46)
[nan 'mail_category_1' 'mail_category_10' 'mail_category_11'
 'mail_category_12' 'mail_category_13' 'mail_category_14'
 'mail_category_15' 'mail_category_16' 'mail_category_17'
 'mail_category_18' 'mail_category_2' 'mail_category_3' 'mail_category_4'
 'mail_category_5' 'mail_category_6' 'mail_category_7' 'mail_category_8'
 'mail_category_9']
Traceback (most recent call last):
  File "basic_analysis.py", line 46, in <module>
    temp_col = pd.DataFrame(label_encoder.transform(test_data[col]))
  File "/usr/local/lib/python2.7/dist-packages/sklearn/preprocessing/label.py", line 148, in transform
    raise ValueError("y contains new labels: %s" % str(diff))
ValueError: y contains new labels: [nan]
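
One plausible reading of the traceback, offered as an assumption rather than a confirmed diagnosis: nan never compares equal to itself, so a nan in test_data may fail to match the nan stored in classes_ and get reported as a new label. Below is a minimal sketch of one common workaround, which replaces nan with a placeholder string in both the training and test data before fitting. The toy Series and the placeholder name MISSING are hypothetical, and the sketch uses Python 3 syntax rather than the Python 2.7 shown in the traceback.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder

# Hypothetical placeholder label standing in for missing values.
MISSING = "missing"

# Toy stand-ins for one column of training_data and test_data.
train = pd.Series(["mail_category_1", np.nan, "mail_category_2"])
test = pd.Series([np.nan, "mail_category_2"])

# nan is never equal to itself, so equality-based lookups cannot match it.
nan_equals_itself = float("nan") == float("nan")  # False

# Replace nan with an ordinary string in BOTH series before encoding.
train_filled = train.fillna(MISSING)
test_filled = test.fillna(MISSING)

# Fit on the training column only, then reuse the encoder on the test column.
encoder = LabelEncoder()
encoder.fit(train_filled)
test_codes = encoder.transform(test_filled)

print(encoder.classes_)
print(test_codes)
```

Because the placeholder is an ordinary string, it is encoded like any other category, and the test column transforms without error as long as every label in it was also present during fitting.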

0 answers:

No answers