概率SVM,如何在“热图”内绘制点?

时间:2019-01-10 14:50:57

标签: python matplotlib

我目前正在尝试实现一个概率svm,到目前为止,我对房间的概率有了一个热图,即:

enter image description here

现在,我希望我的观点发生在该情节中,但是当我尝试以自己的方式看时,我得到:

enter image description here

代码看起来像这样:

import numpy as np
import random
from sklearn.svm import SVC
import math
from scipy.optimize import minimize
import matplotlib.pyplot as plt
from sklearn import decomposition
from sklearn import svm
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
import warnings
warnings.filterwarnings("ignore")

def training_banana(name):
    inputs = []
    file = open(name, "r")
    for line in file:
        vector = line.split()
        coordinate = []
        for i in range(len(vector)):
            coordinate.append(float(vector[i]))
        inputs.append(coordinate)
    file.close()

    return np.array(inputs)


def define_inputs(name, name_targets):
    inputs = training_banana(name)
    targets_array = training_banana(name_targets)
    N = targets_array.shape[0]
    targets = np.zeros(N)
    for i in range(N):
        targets[i] = targets_array[i][0]
    return inputs, targets, N

#training set
inputs_train, targets_train, N = define_inputs('banana_train.txt', 'banana_train_label.txt')
permute = list(range(N))
random.shuffle(permute)
inputs_train = inputs_train[permute, :]
targets_train = targets_train[permute]

#test set
inputs_test, targets_test, N = define_inputs('banana_test.txt', 'banana_test_label.txt')
permute = list(range(N))
random.shuffle(permute)
inputs_test = inputs_test[permute, :]
targets_test = targets_test[permute]


def plotting():
    ax = plt.gca()
    param_C = [0.01, 0.1, 1, 10, 100]
    param_grid = {'C': param_C, 'kernel': ['poly','rbf', 'linear'], 'gamma': [0.1, 0.01, 0.001, 0.0001]}
    clf = GridSearchCV(SVC(class_weight='balanced'), param_grid)
    clf.fit(inputs_train, targets_train)
    index = clf.best_estimator_.n_support_
    clf = SVC(C=clf.best_params_['C'], cache_size=200, class_weight=None, coef0=0.0,
        decision_function_shape='ovr', degree=5, gamma=clf.best_params_['gamma'], kernel=clf.best_params_['kernel'],
        max_iter=-1, probability=True, random_state=None, shrinking=True,
        tol=0.001, verbose=False)
    clf.fit(inputs_train, targets_train)
    support_vectors = []
    for i in range(len(index)):
        support_vectors.append(inputs_train[i])
    support_vectors = np.array(support_vectors)
    xx = np.linspace(-2, 2, 100)
    yy = np.linspace(-2, 2, 100).T
    xx, yy = np.meshgrid(xx, yy)
    Xfull = np.c_[xx.ravel(), yy.ravel()]
    probabilities = clf.predict_proba(inputs_test)
    predicting_classes_pos_targets = []
    predicting_classes_pos_inputs = []
    predicting_classes_neg_targets = []
    predicting_classes_neg_inputs = []
    prob_mesh = clf.predict_proba(Xfull)
    #print(Xfull)
    print(probabilities)
    for i in range(inputs_test.shape[0]):
        if clf.predict([inputs_test[0]]) == 1:
            predicting_classes_pos_targets.append(1)
            predicting_classes_pos_inputs.append(inputs_test[i])
        else:
            predicting_classes_neg_targets.append(-1)
            predicting_classes_pos_inputs.append(inputs_test[i])
    predicting_classes_pos_inputs = np.array(predicting_classes_pos_inputs)

    #plt.scatter(predicting_classes_pos_inputs[:, 0], predicting_classes_pos_inputs[:, 1], c="b", s=30, cmap=plt.cm.Paired)
    plt.imshow(prob_mesh[:, 0].reshape((100, 100)),
                           extent=(3, 9, 1, 5), origin='lower')
    plt.scatter(predicting_classes_pos_inputs[:, 0], predicting_classes_pos_inputs[:, 1], marker='o', c='w', edgecolor='k')
    #plt.colorbar(imshow_handle, cax=ax, orientation='horizontal')
    plt.show()

任何人都知道如何更改代码以使这些点出现在“概率热图”中吗?

0 个答案:

没有答案