Gradient descent for a hinge-loss SVM: Python implementation

Asked: 2016-10-16 12:58:27

Tags: python machine-learning svm gradient-descent

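I am trying to implement gradient descent to minimize the hinge-loss objective of an SVM. The equation I am trying to implement is: [image: Hinge Loss]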

The max function is handled with the subgradient technique, as shown here: [image: Differentiation of max function]
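The two images are missing from this copy of the question. As an assumption, inferred from the posted code (which tests `dp*yi < 1`, accumulates `-xi*yi`, and has no regularization term), the objective and the subgradient of each term would take the standard form:

```latex
J(w) = \sum_{i=1}^{n} \max\bigl(0,\; 1 - y_i\, w^\top x_i\bigr),
\qquad
g_i(w) =
\begin{cases}
  -\,y_i\, x_i & \text{if } y_i\, w^\top x_i < 1, \\
  0            & \text{otherwise,}
\end{cases}
\qquad
g(w) = \sum_{i=1}^{n} g_i(w)
```

Here $g_i(w)$ is one valid subgradient of the $i$-th term. At the kink $y_i\, w^\top x_i = 1$ any point between the two cases is also valid, and choosing $0$ matches the `condition >= 1` branch of the code.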

The problem is that I cannot get it to converge properly. Every time I run the code, I end up with different weights.
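One way to check convergence is to evaluate the hinge objective directly for a given weight vector, since that is the quantity being minimized. The sketch below assumes the unregularized form sum of max(0, 1 - y_i * (w . x_i)); the function name `hinge_objective` and the variable names are hypothetical, and the data are the seven points and labels from the question.

```python
def hinge_objective(w, samples, labels):
    """Sum of max(0, 1 - y * (w . x)) over all labeled samples."""
    total = 0.0
    for x, y in zip(samples, labels):
        margin = y * sum(wj * xj for wj, xj in zip(w, x))
        total += max(0.0, 1.0 - margin)
    return total


# Data from the question, with labels listed in row order 0..6.
samples = [[1, 1], [1, 2], [1, 3], [3, 1], [3, 2], [3, 3], [50, 2]]
labels = [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0]

print(hinge_objective([0.0, 0.0], samples, labels))  # 7.0: every sample has margin 0
print(hinge_objective([1.0, 0.0], samples, labels))  # 6.0: only the three negatives contribute
```

Printing this value after each pass shows whether the weights are actually moving toward a lower objective, independently of how they were initialized.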

Here is the Python implementation:

import sys
import random

## Helper function
## dot_product: calculates the dot product
## of the two vectors passed as arguments

def dot_product(a, b):
    result = []
    for col in range(0, len(a), 1):
        t = a[col] * b[col]
        result.append(t)
    return sum(result)

dataset = [[1,1], [1,2], [1,3], [3,1], [3,2], [3,3], [50,2]]
label_dictionary = {'2': -1.0, '1': -1.0, '4': 1.0, '3': 1.0, '0': -1.0, '5': 1.0, '6': 1.0}

cols = len(dataset[0])
rows = len(dataset)

w = []
for i in range(0, cols, 1):
    w.append(random.uniform(-1, 1))
print('W = ', w)

# gradient descent
eta = 0.01
error = 2
iterations = 1
difference = 1

delf = [0] * cols
#for i in range(iterations):

while difference > 0.001:
    #print('Starting gradient descent.')
    for row_index in range(rows):
        if str(row_index) in label_dictionary:

            # calculate delf
            xi = dataset[row_index]
            yi = label_dictionary.get(str(row_index))

            dp = dot_product(w, xi)  # w*xi
            condition = dp * yi

            # Sub gradient. Diff of max.
            if condition < 1:
                for col in range(cols):
                    delf[col] += -1 * xi[col] * yi
            elif condition >= 1:
                delf = [0] * cols

            # Update
            for j in range(0, cols, 1):
                w[j] = w[j] - eta * delf[j]

            # Compute error
            #print('W', w)
            prev = error
            error = 0
            for i in range(0, rows, 1):
                dp = dot_product(w, dataset[i])
                if str(i) in label_dictionary:
                    yi = int(label_dictionary[str(i)])
                    error += (yi - dp) ** 2

            difference = prev - error
            print('Error', difference)

print('W ', w)


# Predictions
# row labels that are not in
# dictionary

## Start Predictions

## if rowindex is not in labels dictionary
## that is our test sample

for index, x in enumerate(dataset):

    if str(index) not in label_dictionary:
        #print('Testing data found', index, x)
        dp = dot_product(w, x)
        if dp > 0:
            print(1, index)
        else:
            print(0, index)
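For comparison, here is a minimal sketch of batch subgradient descent on the same data. It is not the code from the question, and it departs from it in four ways, all of them assumptions: the subgradient accumulator is rebuilt on every pass (in the posted code `delf` keeps accumulating across samples and is reset only when a sample satisfies the margin), a constant feature 1.0 is appended as a bias term, the loop stops on the hinge objective itself rather than on a squared error, and the random initialization is seeded so repeated runs start from the same weights. The names `dot`, `hinge_objective`, `subgradient`, and `train` are hypothetical.

```python
import random


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def hinge_objective(w, X, y):
    """Unregularized hinge objective: sum of max(0, 1 - y_i * (w . x_i))."""
    return sum(max(0.0, 1.0 - yi * dot(w, xi)) for xi, yi in zip(X, y))


def subgradient(w, X, y):
    """One subgradient of the objective at w, built from scratch on each call."""
    grad = [0.0] * len(w)
    for xi, yi in zip(X, y):
        if yi * dot(w, xi) < 1:          # sample violates the margin
            for j, xij in enumerate(xi):
                grad[j] -= yi * xij
    return grad


def train(X, y, eta=0.01, max_iter=20000, seed=0):
    """Constant-step batch subgradient descent (assumed settings, not tuned)."""
    rng = random.Random(seed)            # seeded: same start on every run
    w = [rng.uniform(-1, 1) for _ in range(len(X[0]))]
    for _ in range(max_iter):
        if hinge_objective(w, X, y) == 0.0:   # no margin violations left
            break
        g = subgradient(w, X, y)
        w = [wj - eta * gj for wj, gj in zip(w, g)]
    return w


points = [[1, 1], [1, 2], [1, 3], [3, 1], [3, 2], [3, 3], [50, 2]]
labels = [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0]
X = [p + [1.0] for p in points]          # append the assumed bias feature

w = train(X, labels)
print('w =', w)
print('hinge objective =', hinge_objective(w, X, labels))
```

The bias feature matters for this particular dataset: the point [3, 3] is three times the point [1, 1] yet carries the opposite label, so no line through the origin can separate the two classes. For data that is not separable, this loop would simply run to `max_iter`; a decreasing step size or a regularization term would be the usual additions in that case.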

0 answers:

No answers