Predicting on categorical data using one-hot encoding

Time: 2019-10-15 03:29:46

Tags: python dataframe

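I am using this code to predict the attack type. I used one-hot encoding for 7 attributes and label encoding for the target. The dataset has more than 180,000 rows and 8 attributes, so I want to make sure the approach is correct.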

import pandas as pd
from sklearn.preprocessing import LabelEncoder

l = LabelEncoder()
# One-hot encode the categorical feature columns in a single call
# (same columns and column order as calling get_dummies once per column)
df1_onehot = pd.get_dummies(df1, columns=['gname', 'city', 'region_txt', 'weaptype1_txt', 'country_txt'])
# Label-encode the target (attack type) into integer classes
df1_onehot['attacktype1_txt'] = l.fit_transform(df1['attacktype1_txt'])
print(df1_onehot.head())

# Split-out validation dataset
from sklearn import model_selection
# Select the target by name and drop it from the features, so it never ends up in X
X = df1_onehot.drop(columns=['attacktype1_txt']).values
Y = df1_onehot['attacktype1_txt'].values
validation_size = 0.20
seed = 4
X_train, X_validation, Y_train, Y_validation = model_selection.train_test_split(
    X, Y, test_size=validation_size, random_state=seed)

from sklearn.metrics import accuracy_score
from sklearn.linear_model import LogisticRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.naive_bayes import GaussianNB
from sklearn.svm import SVC
seed = 7
scoring = 'accuracy'
models = []
models.append(('LR', LogisticRegression(solver='liblinear', multi_class='ovr')))
models.append(('LDA', LinearDiscriminantAnalysis()))
models.append(('KNN', KNeighborsClassifier()))
models.append(('CART', DecisionTreeClassifier()))
models.append(('NB', GaussianNB()))
models.append(('SVM', SVC(gamma='auto')))
# evaluate each model in turn with 10-fold cross-validation on the training split
results = []
names = []
for name, model in models:
    # random_state only takes effect with shuffle=True (scikit-learn >= 0.24 raises an error otherwise)
    kfold = model_selection.KFold(n_splits=10, shuffle=True, random_state=seed)
    cv_results = model_selection.cross_val_score(model, X_train, Y_train, cv=kfold, scoring=scoring)
    results.append(cv_results)
    names.append(name)
    msg = "%s: %f (%f)" % (name, cv_results.mean(), cv_results.std())
    print(msg)
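To check the encoding scheme described above in isolation, here is a minimal sketch on a small synthetic frame. The column names are borrowed from the question, but the values, the 200-row size, and the use of a single DecisionTreeClassifier are assumptions made only for illustration. The sketch label-encodes attacktype1_txt on its own, one-hot encodes the five remaining columns with one pd.get_dummies call after dropping the target, and evaluates with a shuffled KFold, so the target can never leak into the feature matrix.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, cross_val_score
from sklearn.preprocessing import LabelEncoder
from sklearn.tree import DecisionTreeClassifier

# Hypothetical toy data: column names mirror the question, values are made up.
rng = np.random.default_rng(0)
n = 200
df = pd.DataFrame({
    "gname": rng.choice(["g1", "g2", "g3"], size=n),
    "city": rng.choice(["c1", "c2", "c3", "c4"], size=n),
    "region_txt": rng.choice(["r1", "r2"], size=n),
    "weaptype1_txt": rng.choice(["w1", "w2", "w3"], size=n),
    "country_txt": rng.choice(["k1", "k2"], size=n),
    "attacktype1_txt": rng.choice(["Bombing", "Armed Assault", "Hijacking"], size=n),
})

target = "attacktype1_txt"
categorical = ["gname", "city", "region_txt", "weaptype1_txt", "country_txt"]

# Label-encode the target separately; classes are stored in sorted order in le.classes_.
le = LabelEncoder()
y = le.fit_transform(df[target])

# Drop the target first, then one-hot encode every feature column in one call.
X_df = pd.get_dummies(df.drop(columns=[target]), columns=categorical, dtype=np.uint8)
X = X_df.to_numpy()

# Shuffled 10-fold cross-validation; random_state is only used because shuffle=True.
kfold = KFold(n_splits=10, shuffle=True, random_state=7)
scores = cross_val_score(DecisionTreeClassifier(random_state=7), X, y,
                         cv=kfold, scoring="accuracy")
print(f"CART: {scores.mean():.3f} ({scores.std():.3f})")
```

Because each get_dummies call appends its dummy columns at the end of the frame, passing all five columns at once yields the same layout as chaining five calls. Selecting the target by name rather than by a hard-coded position such as column 2 keeps the split correct even if the column order changes.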

0 answers