
时间:2018-05-18 01:58:37

标签: python machine-learning scikit-learn random-forest resampling


  • 1级:10,000
  • 第2类:60,000
  • 3级:7,000
  • 第4类:5,000
  • 第5类:3,500
  • 6类&每个7:2000个样本
  • 第7-15类:每个样本1500个
  • 第16-22类:每个样本500个


import logging
import os
from sklearn.preprocessing import LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.model_selection import cross_val_score
from sklearn.externals import joblib
from sklearn.metrics import classification_report

logger = logging.getLogger(__name__)

def up_sample(data, labels, **kwargs):
    label_counts = Counter(labels)
    max_label = max(label_counts, key=label_counts.get)
    max_label_count = kwargs.get('samples', label_counts[max_label])
    output_text = []
    output_labels = []
    for label, count in label_counts.items():
        label_text = [data_row for data_row, label_row in zip(data, labels) if label_row == label]
        resampled_labels = [label] * max_label_count
        resampled_text = resample(label_text, n_samples=max_label_count, random_state=0)
        output_text = output_text + resampled_text
        output_labels = output_labels + resampled_labels
    return output_text, output_labels

clf = Pipeline(
    steps=(('tfidf_vectorizer', TfidfVectorizer(stop_words='english')),
    ('clf', RandomForestClassifier(n_estimators=250, n_jobs=-1)))

resampled_data, resampled_labels = upsample(data, labels) # UPDATE:  produces ~700,000 samples, which many duplicates

labels = label_encoder.fit_transform(labels)

X_train, X_test, y_train, y_test = train_test_split(data, labels, test_size=0.5, random_state=0) # UPDATE: many duplicates in both training and test data sets as a result of upsampling, y_train)

test_score = clf.score(X_test, y_test)
logger.debug('Test Score: %s', test_score) # 0.98-0.99%

cross_validation_results = cross_val_score(clf, data, labels)
logger.debug('Cross Validation results: %r', cross_validation_results) # [98.7, 99.1, 97.8]

y_test_predicted = clf.predict(X_test)
output_classification_report = classification_report(y_test, y_test_predicted, target_names=label_encoder.classes_)
logger.debug(output_classification_report)  # 0.95-1.0 for precision and recall for all classes

clf_file_name = os.path.join(directory, clf_name)
joblib.dump(clf, clf_file_name)

label_encoder_file_name = os.path.join(directory, label_encoder_name)
joblib.dump(label_encoder, label_encoder_file_name)

# Later, in a different script
clf_file_name = os.path.join(directory, name)
clf = joblib.load(clf_file_name)

label_encoder_file_name = os.path.join(directory, name)
label_encoder = joblib.load(label_encoder_file_name)

predictions = clf.predict(new_data)
logger.debug(clf.score(new_labels, predictions)) # 50-70%



