Computing the hyperplane equation from LIBSVM support vectors and coefficients in Python

Posted: 2018-08-14 08:21:35

Tags: python svm libsvm

I'm using the LIBSVM library in Python and trying to reconstruct the equation of the hyperplane (w'x + b) from the computed support vectors.
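For reference, the two-class decision function LIBSVM evaluates can be written as below. Here x_i are the support vectors, c_i are the entries of sv_coef (which the LIBSVM documentation describes as y_i * alpha_i), rho is m.rho[0] and K is the kernel. Only when K is the linear kernel does it collapse to the w'x + b form:

```latex
\begin{aligned}
f(x) &= \sum_{i \in \mathrm{SV}} c_i \, K(x_i, x) - \rho, \qquad c_i = y_i \alpha_i \\
K(u, v) = u^\top v \;\Rightarrow\; f(x) &= w^\top x + b, \qquad w = \sum_{i \in \mathrm{SV}} c_i \, x_i, \quad b = -\rho
\end{aligned}
```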

The model appears to train correctly, but I can't manually compute predictions that match the output of svm_predict on the test data.

I've tried troubleshooting with the following link from the FAQ, but I still can't compute the correct result. https://www.csie.ntu.edu.tw/~cjlin/libsvm/faq.html#f804

My code is as follows:

from svmutil import *
import numpy as np

ytrain, xtrain = svm_read_problem('small_train.libsvm')
# Change labels from 0 to -1    
for index in range(len(ytrain)):
    if ytrain[index] == 0:
        ytrain[index] = -1.0
print ("Training set loaded...")

m = svm_train(ytrain, xtrain, '-q')
print ("Model trained...")

sv = np.asarray(m.get_SV())
sv_coef = np.asarray(m.get_sv_coef())
sv_indices = np.asarray(m.get_sv_indices())
rho = m.rho[0]

w = np.zeros(len(xtrain[0]))
b = -rho
# weight vector w = sum over i ( coefsi * xi )
for index, coef in zip(sv_indices, sv_coef):
    ai = coef[0]
    for key in xtrain[index-1]:
        w[key] = w[key] + (ai * xtrain[index-1][key])

# From LIBSVM FAQ - Doesn't seem to impact results
# if m.label[1] == -1:
#     w = np.negative(w)
#     b = -b

print(np.round(w,2))

ytest, xtest = svm_read_problem('small_test.libsvm')
# Change labels from 0 to -1  
for index in range(len(ytest)):
    if ytest[index] == 0:
        ytest[index] = -1.0

print ("Test set loaded...")
print ("Predict test set...")
p_label, p_acc, p_val = svm_predict(ytest, xtest, m)

print("p_label: ", p_label)
print("p_val: ", np.round(p_val,3))

for i in range(len(ytest)):
    wx = 0
    for key in xtest[i]:
        wx = wx + (xtest[i][key] * w[key])
    print("Manual calc: ", np.round(wx + b,3))
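The manual calculation above can also be written in vectorized form. This is a minimal sketch assuming 0-based integer feature indices and a known feature count (6 for the files below); the helpers to_dense and linear_decision_values are hypothetical names, not part of LIBSVM.

```python
import numpy as np


def to_dense(rows, n_features):
    """Convert LIBSVM-style sparse rows ({index: value} dicts) to a dense array.

    Assumes 0-based feature indices smaller than n_features.
    """
    X = np.zeros((len(rows), n_features))
    for r, row in enumerate(rows):
        for key, value in row.items():
            X[r, key] = value
    return X


def linear_decision_values(sv_rows, sv_coef, rho, x_rows, n_features):
    """Return (w'x + b for every row of x_rows, w), built from the support vectors.

    sv_rows: support vectors as {index: value} dicts
    sv_coef: one coefficient (y_i * alpha_i) per support vector, e.g. 1-tuples
    rho:     the model's rho, so b = -rho
    """
    SV = to_dense(sv_rows, n_features)
    coef = np.asarray(sv_coef, dtype=float).ravel()
    w = coef @ SV  # w = sum_i coef_i * sv_i
    b = -rho
    X = to_dense(x_rows, n_features)
    return X @ w + b, w
```

With the trained model, a call such as linear_decision_values(m.get_SV(), m.get_sv_coef(), m.rho[0], xtest, 6) should only be expected to reproduce p_val when the model was trained with a linear kernel.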

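My understanding is that the result I compute manually with wx + b should match the values in p_val. I've tried negating w and b, but I still can't get the same results as p_val.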
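The expectation that w'x + b matches p_val only holds for a linear kernel. With an RBF kernel the decision value is a kernel expansion over the support vectors, so no single w and b reproduce it. A minimal numpy sketch of that expansion follows; it assumes dense arrays and that gamma is the value the model was actually trained with (the LIBSVM README documents the default as 1/num_features). The function name rbf_decision_values is hypothetical.

```python
import numpy as np


def rbf_decision_values(SV, coef, rho, gamma, X):
    """Decision values sum_i coef_i * exp(-gamma * ||sv_i - x||^2) - rho.

    SV:    dense support vectors, shape (n_sv, n_features)
    coef:  y_i * alpha_i per support vector, shape (n_sv,)
    rho:   the model's rho
    gamma: RBF parameter used in training (assumed known)
    X:     dense inputs, shape (n_samples, n_features)
    """
    SV = np.asarray(SV, dtype=float)
    X = np.asarray(X, dtype=float)
    coef = np.asarray(coef, dtype=float).ravel()
    # Squared Euclidean distance between every input and every support vector.
    sq_dist = ((X[:, None, :] - SV[None, :, :]) ** 2).sum(axis=2)
    K = np.exp(-gamma * sq_dist)  # shape (n_samples, n_sv)
    return K @ coef - rho
```

Because each term depends on the distance ||sv_i - x|| rather than on a dot product with x, the result is not linear in x, which is why the w computed in the question cannot match p_val for this kernel.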

The datasets I'm using (in LIBSVM format) are:

small_train.libsvm

0 0:-0.36 1:-0.91 2:-0.99 3:-0.57 4:-1.38 5:-1.54
1 0:-1.4 1:-1.9 2:0.09 3:0.29 4:-0.3 5:-1.3
1 0:-0.43 1:1.45 2:-0.68 3:-1.58 4:0.32 5:-0.14
1 0:-0.76 1:0.3 2:-0.57 3:-0.33 4:-1.5 5:1.84

small_test.libsvm

1 0:-0.97 1:-0.69 2:-0.96 3:1.05 4:0.02 5:0.64
0 0:-0.82 1:-0.17 2:-0.36 3:-1.99 4:-1.54 5:-0.31

Is w being computed correctly? Is p_val the right value to compare against?

Many thanks for your help, as always.

1 answer:

Answer 0 (score: 0)

I managed to get the values to match by changing:


m = svm_train(ytrain, xtrain, '-q')

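to:

m = svm_train(ytrain, xtrain, '-q -t 0')

Looking at the documentation, the default kernel type is non-linear (radial basis function). After setting a linear kernel, the results now line up.

The available kernel types are:

-t kernel_type : set type of kernel function (default 2)
    0 -- linear: u'*v
    1 -- polynomial: (gamma*u'*v + coef0)^degree
    2 -- radial basis function: exp(-gamma*|u-v|^2)
    3 -- sigmoid: tanh(gamma*u'*v + coef0)
    4 -- precomputed kernel (kernel values in training_set_file)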
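Why the linear kernel (-t 0) makes the manual w'x + b agree with p_val can be checked without LIBSVM: with K(u, v) = u'v, the kernel expansion sum_i coef_i <sv_i, x> - rho equals (sum_i coef_i sv_i)'x - rho. A minimal sketch using random stand-in data (not the question's data); kernel_form and primal_form are hypothetical names.

```python
import numpy as np


def kernel_form(SV, coef, rho, X):
    """Linear-kernel decision values: sum_i coef_i * <sv_i, x> - rho."""
    K = X @ SV.T  # K[j, i] = <x_j, sv_i>
    return K @ coef - rho


def primal_form(SV, coef, rho, X):
    """The same values written as w'x + b, with w = sum_i coef_i * sv_i and b = -rho."""
    w = SV.T @ coef
    return X @ w - rho


rng = np.random.default_rng(0)
SV = rng.normal(size=(3, 6))  # stand-in support vectors (random)
coef = rng.normal(size=3)     # stand-in y_i * alpha_i values
rho = 0.7
X = rng.normal(size=(2, 6))   # stand-in test points

print(np.allclose(kernel_form(SV, coef, rho, X), primal_form(SV, coef, rho, X)))  # True
```

The equality is just linearity of the dot product, so it holds for any linear-kernel model; it has no counterpart for the default RBF kernel.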