我是机器学习的新手。我在大学选修了机器学习课程,其中一项家庭作业是使用朴素贝叶斯分类器对 CIFAR-10 进行分类。我已经花了两天时间尝试完成这项作业,但结果总是不对。老师告诉我们,分类准确率应该在 60-70% 之间,而我的分类准确率只有 20%。我们使用每个颜色通道的平均值(红色平均值、绿色平均值、蓝色平均值)作为特征,因此训练数据矩阵的大小为 50000x3。我按类别对矩阵的行进行分组,并为每个类别对应的子矩阵计算了 mu 和 sigma。代码如下:
import numpy as np
import math
from ex2_class import DataMatrix
import matplotlib.pyplot as plt
from keras.datasets import cifar10
def cifar_10_evaluate(pred, gt):
    """Return the percentage of predictions that equal the ground truth.

    Args:
        pred: Indexable sequence of predicted labels.
        gt: Indexable sequence of reference labels; it is read at every
            index that is valid for ``pred``.

    Returns:
        The share of matching positions, scaled to the range 0-100.

    Raises:
        ZeroDivisionError: If ``pred`` is empty.
    """
    total = len(pred)
    hits = sum(1 for idx in range(total) if pred[idx] == gt[idx])
    return (hits / total) * 100
def normpdf(x,mu,var):
    """Evaluate the univariate normal density at ``x``.

    Args:
        x: Point at which the density is evaluated.
        mu: Mean of the distribution.
        var: Variance of the distribution (not the standard deviation).

    Returns:
        The density value ``exp(-(x - mu)**2 / (2 * var)) / sqrt(2 * pi * var)``.
    """
    deviation = x - mu
    exponent = -(deviation ** 2) / (2 * var)
    numerator = math.exp(exponent)
    normaliser = math.sqrt(2 * math.pi * var)
    return numerator / normaliser
def cifar_10_features(x):
    """Compute the three channel-mean features of one flattened image.

    The input is read as three consecutive runs of 1024 values
    (indices 0-1023, 1024-2047 and 2048-3071), and each run's sum is
    divided by 1024.

    Args:
        x: Flat sequence of pixel values.  The caller in this file treats
            the three runs as the red, green and blue channels.

    Returns:
        A list of three Python floats, one mean per run.
    """
    pixels_per_channel = 1024
    features = []
    for start in (0, 1024, 2048):
        # Accumulate in float64: adding small integer pixels (e.g. uint8) in
        # their own dtype can wrap around and corrupt the mean.
        channel = np.asarray(x[start:start + pixels_per_channel], dtype=np.float64)
        features.append(float(channel.sum()) / pixels_per_channel)
    return features
def cifar_10_bayes_learn(F,label):
    """Estimate per-class Gaussian parameters for classes 0 through 9.

    For each class the rows of ``F`` carrying that label are selected and
    the mean and population variance of columns 0, 1 and 2 are computed.

    Args:
        F: 2-D array-like of features, one row per sample, with at least
            three columns.
        label: Class ids, one per row of ``F``.  A flat vector, an (N, 1)
            column, or a plain Python list are all accepted.

    Returns:
        A three-element list ``[means, variances, prior]``.  ``means`` and
        ``variances`` are arrays of shape (10, 3), indexed by class and
        feature.  ``prior`` is the constant ``1/10``; it does not depend on
        how often each class occurs in ``label``.
    """
    features = np.asarray(F)
    # Flatten the labels so that an (N, 1) column selects whole rows of F
    # instead of producing a two-axis index that collapses the result to 1-D.
    labels = np.asarray(label).ravel()
    ms = []
    sg = []
    for i in range(10):
        # Explicit column list keeps the output at exactly three features and
        # still raises IndexError when F has fewer than three columns.
        class_i = features[labels == i][:, [0, 1, 2]]
        ms.append(class_i.mean(axis=0))
        sg.append(class_i.var(axis=0))
    return [np.array(ms), np.array(sg), 1/10]
def cifar_10_classify(f,ms,gs,p):
    """Pick the class in 0-9 with the largest naive Bayes score for one sample.

    Args:
        f: Raw sample; it is converted with ``cifar_10_features`` first.
        ms: Per-class means, indexed as ``ms[class, feature]``.
        gs: Per-class variances, indexed as ``gs[class, feature]``.
        p: Prior multiplied into every class score.

    Returns:
        The lowest class index whose score equals the maximum score.
    """
    feats = cifar_10_features(f)
    scores = []
    for cls in range(10):
        # One Gaussian likelihood per feature, treated as independent.
        likelihoods = [
            normpdf(feats[ch], ms[cls, ch], gs[cls, ch]) for ch in range(3)
        ]
        scores.append(likelihoods[0] * likelihoods[1] * likelihoods[2] * p)
    top = max(scores)
    winners = [cls for cls, score in enumerate(scores) if top == score]
    return winners[0]
def save(mus, sigmas):
    """Write the class means and variances to disk with ``np.save``.

    Args:
        mus: Data stored under the file name ``'class_mus'``.
        sigmas: Data stored under the file name ``'class_variances'``.
    """
    outputs = (('class_mus', mus), ('class_variances', sigmas))
    for target, payload in outputs:
        np.save(target, payload)
def main():
    """Learn class statistics, classify the test samples and print the accuracy.

    Training features are loaded from ``'training_mus.npy'``; the labels and
    the test samples come from a ``DataMatrix`` built from ``path_to_data``.
    """
    dataset = DataMatrix(path_to_data)
    # The raw training data is still fetched, although learning below relies
    # on the feature matrix loaded from disk.
    train_data = dataset.get_training_data()
    train_labels = dataset.get_training_labels()
    test_data = dataset.get_test_data()
    test_labels = dataset.get_test_labels()
    # Evaluate on at most the first 10000 test samples.
    test_data = test_data[0:10000, :]
    test_labels = test_labels[0:10000]
    train_features = np.load('training_mus.npy')
    ms, gs, p = cifar_10_bayes_learn(train_features, train_labels)
    predictions = [cifar_10_classify(sample, ms, gs, p) for sample in test_data]
    print(cifar_10_evaluate(predictions, test_labels))
请帮我找出问题所在:我已经花了两天时间,仍然没有发现代码哪里有问题。谢谢!