sklearn在尝试预测数字时始终预测1

时间:2017-11-22 02:29:30

标签: python machine-learning scikit-learn blender scikits

我试图编写预测混合器中曲线数字的代码。 所以我将曲线转换为sklearn使用的矩阵并尝试预测数字,不幸的是,无论我做什么,预测总是1。

2d矩阵(它看起来像我在搅拌机中的圆圈):

[[  0.   0.   0.   0.   0.   0.   0.   0.]
 [  0.   0.   0.  25.  25.   0.   0.   0.]
 [  0.  25.  25.  25.   0.  25.  25.   0.]
 [  0.  25.   0.   0.   0.   0.  25.   0.]
 [  0.  25.   0.   0.   0.   0.  25.   0.]
 [  0.  25.   0.   0.   0.   0.  25.   0.]
 [  0.   0.  25.  25.  25.  25.   0.   0.]
 [  0.   0.   0.   0.   0.   0.   0.   0.]]

代码:

import bpy
import numpy as np
from sklearn import datasets
from sklearn import svm
import scipy.misc

ob = bpy.context.object
assert ob.type == 'CURVE' # throw error if it's not a curve
curve = ob.data
spline = curve.splines.active # let's assume there's only one
assert spline.type == 'BEZIER' # throw error if it's not a bezier

shortest = None
shortestDist = 10000
shortest_x = None
shortestDist_x = 10000
result = []
for point in spline.bezier_points:
    dist = point.co.y
    dist_x = point.co.x
    if dist < shortestDist : #test if better so far
        shortest = point
        shortestDist = dist   
    if dist_x < shortestDist_x : #test if better so far
        shortest_x = point
        shortestDist_x = dist  

print(1 / abs(shortest.co.y))
result.append([shortest, shortestDist, dist, dist_x])
mult_y = 1 / abs(shortest.co.y)
mult_x = 1 / abs(shortest_x.co.x)
point_pos = []
for point in spline.bezier_points:
    loc = point.co.y
    loc_x = point.co.x
    max_y = loc * mult_y
    max_x = loc_x * mult_x
    point_pos.append([loc, loc_x])

matrix = np.zeros((8, 8))
pixel = []

for index in enumerate(matrix):
    matrix_to_co_y = 1 / len(matrix) * index[0]
    for index_y in enumerate(matrix[index[0]]):
        matrix_to_co_x = 1 / len(matrix) * index_y[0]
        #print(matrix_to_co_y)
        for point in point_pos:
            if matrix_to_co_y > point[0] > matrix_to_co_y - 1 / len(matrix):
                if matrix_to_co_x > point[1] > matrix_to_co_x - 1 / len(matrix):
                    pixel.append([index[0], index_y[0]])

for p in enumerate(pixel):
    matrix[p[1][0]][p[1][1]] = 25

flat = np.ravel(matrix)


digits = datasets.load_digits()

clf = svm.SVC(gamma=0.001, C=100)

x,y = digits.data[:-1], digits.target[:-1]
clf.fit(x,y)
print('Prediction:',clf.predict([flat]))

print(matrix)

我不知道自己做错了什么。 任何帮助将不胜感激

2 个答案:

答案 0 :(得分:0)

这可能是输入图像或分类器的问题。 要测试问题所在,

1)尝试使用多个输入图像。尝试为每个数字制作一个图像,0-9。如果您的分类器为所有这些分类器预测“1”,则问题可能出在分类器中。但如果它可以预测其中一些,那么很可能只是你的单输入图像造成了麻烦。

2)尝试使用其他分类器。几乎任何东西都可以在digits数据集上为您提供良好的性能。我尝试使用RandomForestClassifier,它正确地将您的图像预测为“0”。

概念证明:

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn import datasets
my_input = np.array(
 [[  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.],
 [  0.,   0.,   0.,  25.,  25.,   0.,   0.,   0.],
 [  0.,  25.,  25.,  25.,   0.,  25.,  25.,   0.],
 [  0.,  25.,   0.,   0.,   0.,   0.,  25.,   0.],
 [  0.,  25.,   0.,   0.,   0.,   0.,  25.,   0.],
 [  0.,  25.,   0.,   0.,   0.,   0.,  25.,   0.],
 [  0.,   0.,  25.,  25.,  25.,  25.,   0.,   0.],
 [  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.]])
iris = datasets.load_iris()
digits = datasets.load_digits()
clf = RandomForestClassifier()
clf.fit(digits.data, digits.target)
clf.predict(my_input.reshape(1, -1))
# Outputs array([0])

答案 1 :(得分:0)

您应验证预测的概率是否高于默认阈值。在这种情况下,您将总是找到1作为预期的班级。为了验证概率值,您可以运行以下代码,因为您的测试功能在代码中被标识为:flat。

clf = svm.SVC(gamma=0.001, C=100) # This line of code is from your post
x,y = digits.data[:-1], digits.target[:-1] # This line of code is from your post
clf.fit(x,y)  # This line of code is from your post
y_pred=svc.predict_proba(flat) # Here, I predict the probabilities, using the test data you have named flat.

# The predicted probabilities are printed bellow
print(y_pred)

当然,您已经浏览了上面代码印刷的预测概率值y_pred。如果所有这些概率均大于0.5(这是二进制分类的默认阈值),则应使用下面的代码,并将阈值更改为高于上面预测的概率最小值的值。例如,假设概率的最小值为0.55,则阈值应高于0.55。我选择0.6。但是,如果0.6大于概率的最大值,则

threshold=0.6    
ypred=(y_pred[:,1]>threshold).astype('int') 
print(ypred)

您可以尝试阈值的多个值,并测试哪个值可以生成您感兴趣的最佳指标(准确性得分,查全率,精确度等)。