我试图用L1正则化写一个逻辑回归。我使用对数似然的随机梯度上升作为成本函数。但我的函数总是给出接近0的值。我的数据集来自http://archive.ics.uci.edu/ml/datasets/HTRU2
以下是示例,其中y_hat是预测函数的输出,y [i]是目标的真值:
y_hat: [ 0.08707034] y[i]: 0
y_hat - y[i]: [ 0.08707034]
y_hat: [ 0.06406335] y[i]: 0
y_hat - y[i]: [ 0.06406335]
y_hat: [ 0.04818193] y[i]: 0
y_hat - y[i]: [ 0.04818193]
y_hat: [ 0.19760354] y[i]: 1
y_hat - y[i]: [-0.80239646]
y_hat: [ 0.08982549] y[i]: 0
y_hat - y[i]: [ 0.08982549]
y_hat: [ 0.0844086] y[i]: 1
y_hat - y[i]: [-0.9155914]
这是阈值为0.5时的准确度
threshold= 0.5
acc= 0.0159217877095
这是我的后勤sgd代码
def logreg_sgd(X, y, alpha = .001, iters = 100000, eps=1e-4):
n, d = X.shape
theta = numpy.zeros((d, 1))
k = 0
lam = 0.001
for k in range(iters):
i = k%n
x = X[i, :]
xT = numpy.transpose([x])
y_hat = sigmoid(x, theta)
beta = de_norm1(theta)
func_g = (y[i] - y_hat)*xT + lam*beta
theta_k = theta.copy()
theta = theta + alpha*func_g
for delta in abs(theta-theta_k):
if delta > eps:
break
return theta
def de_norm1(theta):
d, _ = theta.shape
beta = numpy.zeros((d, 1))
for i in range(d):
if theta[i,0] < 0:
beta[i,0] = -1
elif theta[i,0] > 0:
beta[i,0] = 1
return beta
def sigmoid(X, theta):
z = numpy.dot(X, theta)
value = 1.0/(1.0 + numpy.exp(-z))
return value
我无法弄清楚为什么这是错误的
以下是第二个编辑内容。
这是我画的roc_curve。
以及绘制roc cureve的代码
# plot the ROC curve of your prediction
# x aixes: TPR = TP / ( TP + FN )
# y aixes: FPR = FP / ( FP + TN )
def plot_roc_curve(X_test, y_true, theta):
k = 51
FPR_x = numpy.zeros(k)
TPR_y = numpy.zeros(k)
for n in range(k):
threshold = n/(k-1)
y_pred, FPR, TPR = predict(X_test, y_true, theta, threshold)
tn, fp, fn, tp = sklearn.metrics.confusion_matrix(y_true, y_pred).ravel()
FPR_x[n] = fp/(fp+tn)
TPR_y[n] = tp/(tp+fn)
print(FPR_x[n], TPR_y[n])
plt.xlabel('FPR')
plt.ylabel('TPR')
plt.plot(FPR_x, TPR_y, '-')
# plt.scatter(FPR_x, TPR_y, marker='o', color='blue')
plt.show()
这是预测(分类)功能
def predict(X, y_true, theta, threshold):
value = sigmoid(X, theta)
row, col = X.shape
y_hat = numpy.zeros(row)
P = sum(y_true)
N = row - P
FP = 0.0
TP = 0.0
print('P =', P)
for i, val in enumerate(value):
if val > threshold:
y_hat[i] = 1
TP+=1
else:
y_hat[i] = 0
if y_true[i] != y_hat[i]:
if y_true[i] == 0:
FP+=1
FPR = FP/N
TPR = TP/P
return y_hat, FPR, TPR
我的数据加载功能
def load_train_test_data(train_ratio=.8):
data = pandas.read_csv('./HTRU2/HTRU_2.csv', header=None)
X = data.iloc[:,:8]
X = numpy.concatenate((numpy.ones((len(X), 1)), X), axis=1)
y = data.iloc[:,8]
y = numpy.array(y)
return sklearn.model_selection.train_test_split(X, y, test_size = 1 - train_ratio, random_state=0)
比例功能
def scale_features(X_train, X_test, low=0, upp=1):
minmax_scaler = sklearn.preprocessing.MinMaxScaler(feature_range=(low, upp)).fit(numpy.vstack((X_train, X_test)))
X_train_scale = minmax_scaler.transform(X_train)
X_test_scale = minmax_scaler.transform(X_test)
return X_train_scale, X_test_scale
我的主要功能
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import numpy
import pandas
import sklearn.metrics
import sklearn.model_selection
import sklearn.linear_model
import sklearn.preprocessing
import matplotlib.pyplot as plt
def main(argv):
X_train, X_test, y_train, y_test = load_train_test_data(train_ratio=.8)
X_train_scale, X_test_scale = scale_features(X_train, X_test, 0, 1)
theta = logreg_sgd(X_train_scale, y_train)
print('theta:\n', theta)
plot_roc_curve(X_test, y_test, theta)
if __name__ == "__main__":
main(sys.argv)
示例数据:
140.5625,55.68378214,-0.234571412,-0.699648398,3.199832776,19.11042633,7.975531794,74.24222492,0
102.5078125,58.88243001,0.465318154,-0.515087909,1.677257525,14.86014572,10.57648674,127.3935796,0
103.015625,39.34164944,0.323328365,1.051164429,3.121237458,21.74466875,7.735822015,63.17190911,0
136.75,57.17844874,-0.068414638,-0.636238369,3.642976589,20.9592803,6.89649891,53.59366067,0
88.7265625,40.67222541,0.600866079,1.123491692,1.178929766,11.4687196,14.26957284,252.5673058,0
93.5703125,46.69811352,0.53190485,0.416721117,1.636287625,14.54507425,10.6217484,131.3940043,0
119.484375,48.76505927,0.03146022,-0.112167573,0.99916388,9.279612239,19.20623018,479.7565669,0
130.3828125,39.84405561,-0.158322759,0.389540448,1.220735786,14.37894124,13.53945602,198.2364565,0
107.25,52.62707834,0.452688025,0.170347382,2.331939799,14.48685311,9.001004441,107.9725056,0
107.2578125,39.49648839,0.465881961,1.162877124,4.079431438,24.98041798,7.397079948,57.78473789,0
142.078125,45.28807262,-0.320328426,0.283952506,5.376254181,29.00989748,6.076265849,37.83139335,0
133.2578125,44.05824378,-0.081059862,0.115361506,1.632107023,12.00780568,11.97206663,195.5434476,0
134.9609375,49.55432662,-0.135303833,-0.080469602,10.69648829,41.34204361,3.893934139,14.13120625,0
117.9453125,45.50657724,0.325437564,0.661459458,2.836120401,23.11834971,8.943211912,82.47559187,0
138.1796875,51.5244835,-0.031852329,0.046797173,6.330267559,31.57634673,5.155939859,26.14331017,0
114.3671875,51.94571552,-0.094498904,-0.287984087,2.738294314,17.19189079,9.050612454,96.61190318,0
109.640625,49.01765217,0.13763583,-0.256699775,1.508361204,12.07290134,13.36792556,223.4384192,0
100.8515625,51.74352161,0.393836792,-0.011240741,2.841137124,21.63577754,8.302241891,71.58436903,0
136.09375,51.69100464,-0.045908926,-0.271816393,9.342809365,38.09639955,4.345438138,18.67364854,0
99.3671875,41.57220208,1.547196967,4.154106043,27.55518395,61.71901588,2.20880796,3.662680136,1
100.890625,51.89039446,0.627486528,-0.026497802,3.883779264,23.04526673,6.953167635,52.27944038,0
105.4453125,41.13996851,0.142653801,0.320419676,3.551839465,20.75501684,7.739552295,68.51977061,0
95.8671875,42.05992212,0.326386917,0.803501794,1.83277592,12.24896949,11.249331,177.2307712,0
117.3671875,53.90861351,0.257953441,-0.405049077,6.018394649,24.76612335,4.807783224,25.52261561,0
106.6484375,56.36718209,0.378355072,-0.266371607,2.43645485,18.40537062,9.378659682,96.86022536,0
112.71875,50.3012701,0.279390953,-0.129010712,8.281772575,37.81001224,4.691826852,21.27620977,0
130.8515625,52.43285734,0.142596727,0.018885442,2.64632107,15.65443599,9.464164025,115.6731586,0
119.4375,52.87481531,-0.002549267,-0.460360287,2.365384615,16.49803188,9.008351898,94.75565692,0
123.2109375,51.07801208,0.179376819,-0.17728516,2.107023411,16.92177312,10.08033334,112.5585913,0
102.6171875,49.69235371,0.230438984,0.193325371,1.489130435,16.00441146,12.64653474,171.8329021,0
110.109375,41.31816988,0.094860398,0.68311261,1.010033445,13.02627521,14.66651082,231.2041363,0
99.9140625,43.91949797,0.475728501,0.781486196,0.619565217,9.440975862,20.1066391,475.680218,0
128.34375,52.17210664,-0.049280401,-0.208256987,2.173913043,12.9939472,9.965757364,141.5100843,0
142.0546875,53.87315957,-0.470772686,-0.125946417,4.423076923,27.08351266,6.681658306,45.94403008,0
121.1328125,47.6326062,0.177360308,0.024918111,2.151337793,20.55243738,9.920468181,99.74707919,0
102.328125,48.98040255,0.315729409,-0.202183315,1.898829431,13.83904002,11.61993869,172.1303732,0
147.8359375,53.62263651,-0.131079596,-0.288851172,2.692307692,17.08088101,8.849177975,92.20174502,0
108.0390625,34.91024257,0.321156562,1.821631493,3.899665552,23.72205203,7.506209958,60.88691267,0
107.875,37.33065932,0.49600476,1.481815856,1.173913043,12.01691346,14.53428973,252.6947381,0
118.84375,45.9319193,-0.109242666,0.137683548,2.33277592,14.71602871,9.634175054,118.6696797,0
138.4609375,48.91716569,-0.039591916,-0.176243068,2.443143813,18.3133067,8.672894053,83.06924213,0
116.203125,47.34586165,0.211946824,-0.022177703,3.606187291,18.94498977,7.035644684,59.23122572,0
120.5546875,45.54990543,0.282923998,0.419908714,1.358695652,13.07903424,13.31214143,212.5970294,1
121.8828125,53.04267461,0.200520721,-0.282219034,2.116220736,16.58087621,8.947602793,91.01176155,0
125.2109375,51.17519729,0.139851288,-0.385736754,1.147993311,12.41401211,14.06879728,228.1315536,0
107.90625,48.08414459,0.460846577,0.29651005,1.993311037,13.84106954,9.969395408,128.7447168,0
106.28125,43.02178545,0.408868006,1.032014666,1.610367893,17.25115554,12.11019331,152.0149562,0
106.3359375,45.05002035,0.418645099,0.603995884,1.200668896,12.38856143,13.30645184,209.41199,0
125.734375,52.65772207,0.026516673,-0.429632907,4.850334448,29.93619483,6.361837308,40.25501275,0
113.546875,49.50029346,0.130001201,-0.202839025,2.407190635,14.42205142,9.310343318,113.6874714,0
134.0390625,51.80045885,-0.195844789,-0.396816077,1.107859532,13.23858397,13.77580037,208.4202575,0
105.1171875,45.09202762,0.464847891,0.878058377,4.283444816,23.96731526,6.562543005,46.66728734,0
95.328125,44.66789069,0.386495074,0.755115427,2.694816054,17.9985973,9.094177089,97.80243629,0
119.3359375,47.506953,0.220316758,0.645717725,0.79264214,9.540907141,18.76653977,441.5133427,0
136.1875,51.95291588,-0.070431774,-0.482219687,0.849498328,9.677531027,18.73655411,431.3904454,0
112.859375,55.10625168,0.174766173,-0.404019163,3.032608696,19.69431374,7.266252257,58.03777067,0
108.625,52.74614915,0.453556415,0.069731528,2.304347826,16.18365586,9.780440566,114.9993838,0
113.953125,49.2214161,0.234723211,0.289792216,1.081103679,13.48209307,14.25608113,216.8362204,0
141.96875,50.47089779,0.244974491,-0.342664657,2.823578595,16.23818776,8.207743613,85.53258352,0
136.5,49.9327673,0.044623267,-0.374311456,1.555183946,12.81353792,13.31433912,214.813089,0
83.6796875,36.37928102,0.572531753,2.66461052,4.0409699,23.16912864,7.006681423,53.51400467,0
27.765625,28.66604164,5.770087392,37.4190088,73.11287625,62.07021971,1.268206006,1.082920221,1
135.859375,51.93727202,0.065768774,-0.366114187,20.77424749,52.77264803,2.730908619,6.607439551,0
112.09375,48.81156969,0.418565459,0.350156301,2.204013378,17.37868175,9.520551079,100.7875964,0
126.8671875,53.1293191,0.13633915,-0.588709439,1.149665552,13.96514443,13.23049959,186.2685104,0
117.5390625,47.73296528,0.173139263,-0.150653604,1.060200669,14.28934355,14.17637248,208.2780851,0
143.0859375,49.92197464,-0.157561213,-0.153332697,3.563545151,21.28808157,7.337117054,59.16844081,0
101.296875,39.43395574,0.390053688,1.551969375,4.925585284,26.32242163,6.086053659,39.11620774,0
119.8984375,53.82550508,0.143378486,-0.528427658,4.04180602,24.57913147,6.581293412,44.89951492,0
123.125,50.33124651,-0.087091427,0.087932382,1.280936455,10.68864639,14.63669101,288.668932,0
102.046875,48.79050551,0.45222638,0.272447732,2.37541806,13.9284014,9.127499454,116.0232222,0
119.4453125,53.14305702,0.012830273,-0.378955989,2.932274247,17.9297569,8.289888515,81.34651657,0
128.515625,54.94585181,-0.012552759,-0.658278628,2.891304348,17.75294666,8.913745414,94.08210337,0
128.15625,46.89690113,-0.179233074,-0.005819915,4.193979933,22.25815766,6.451755484,46.48663173,0
115.6171875,40.29037592,0.110702345,0.513224267,11.63963211,39.95655753,3.640288988,12.68457562,0
136.7421875,44.39123754,-0.22192524,0.908084632,2.105351171,14.49837742,10.13157115,128.3951486,0
135.265625,48.14390609,0.015920939,-0.15877212,8.539297659,31.13487695,4.082788387,17.27267344,0
113.9609375,52.24736871,0.127976811,-0.457499415,4.407190635,26.29776588,6.709564866,47.4057088,0
107.796875,45.6803362,0.655279783,0.954879021,1.7090301,15.1907807,11.52025038,150.3053634,0
124.5,57.35361802,-0.014849043,-0.550963937,4.783444816,27.50164045,6.090448645,37.81809112,0
119.296875,46.45417086,0.202629139,0.12837064,3.748327759,18.8510099,6.414682286,50.85055687,0
148.3828125,51.200757,-0.113195798,-0.50223559,1.408026756,12.08791939,12.5121354,201.1278905,0
109.4921875,53.2901838,0.2528458,-0.319022964,4.132943144,25.89210734,6.741542034,46.83080307,0
112.125,46.30840906,0.721646098,0.612454163,1.173076923,11.04918969,14.6307442,273.2509626,0
128.7734375,45.80669555,0.086169154,-0.031764808,2.66722408,15.93295829,8.75667197,95.36727143,0
140.265625,48.93721813,0.03252958,0.119064502,2.315217391,19.87317992,9.67260138,98.89698457,0
87.515625,51.76343189,1.070588903,0.74283956,15.67809365,50.90591579,3.141187931,8.440045483,0
132.140625,42.09582342,0.143191723,0.876730035,1.863712375,13.26595667,10.25798651,140.0407088,0
104.078125,45.24078107,0.532040422,0.743853067,1.43645485,15.41478275,11.89911604,150.9872549,0
122.6015625,53.79697654,-0.051964773,-0.379729027,2.636287625,15.17095406,9.519292364,117.7422254,0
114.28125,41.25396525,0.41182113,0.616996141,2.412207358,20.42794216,9.198391753,88.37057957,0
112.4375,38.2956733,0.501943444,1.07484029,2.81270903,18.13688307,7.859968426,71.29944944,0
23.625,29.94865398,5.688038235,35.98717152,146.5685619,82.39462399,-0.274901598,-1.121848281,1
94.5859375,35.77982308,1.187308683,3.68746932,6.071070234,29.76039993,5.318766827,28.69804799,1
137.2421875,46.45474042,0.045257133,-0.438857507,59.4958194,77.75535652,0.71974817,-1.183162032,0
123.53125,53.34878418,0.072077648,-0.071600995,0.781772575,10.57083301,17.11829958,339.6608262,0
70.0234375,35.28067478,1.157657193,4.546692371,3.003344482,19.57538355,7.954436097,71.96015886,0
129.375,44.56841651,0.049779493,0.506330188,3.60451505,21.13303805,7.181384025,56.85662961,0
97.140625,47.77089438,0.625218075,0.740796144,4.193143813,26.46526062,6.927045631,49.62852693,0
101.96875,46.31632702,0.439814307,0.294261355,1.748327759,16.4866229,10.8103928,127.7333664,0
答案 0 :(得分:1)
我将使用答案,因为在评论中我没有足够的空间。我希望能指出你正确的方向。您可以使用sklearn获取值以检查模型。我下载了您的数据集并创建了(以快速和肮脏的方式)平衡数据集
X_train, X_test, y_train, y_test =
sklearn.model_selection.train_test_split(data, labels,
test_size = 1 - 0.8, random_state=0)
X_train_bal1 = X_train[y_train == 1]
Y_train1 = y_train[y_train == 1]
X_train_bal0 = X_train[y_train == 0].sample(len(X_train_bal1))
Y_train0 = y_train[y_train == 0].sample(len(X_train_bal1))
X = pd.concat([X_train_bal1, X_train_bal0])
Y = pd.concat([Y_train1, Y_train0])
然后我使用sklearn进行逻辑回归
logisticb = LogisticRegression()
logisticb.fit(X,Y)
logisticb.score(X,Y)
现在你有1365个观察标记为1和1365个观察标记为0.你得到0.943的分数。非常好。在原始数据集中,您处理大约1%的标记为1的观察结果,大约99%标记为0.我正在检查您的代码,因为我正在键入此内容。我会在发现错误后立即更新此答案。
编辑:我检查了你的代码,但我必须承认我只是简化它,我会重写它。对于您想要做的事情(逻辑回归)过于复杂且难以调试。我认为你最好投入时间调试代码所需的时间。
从积极的方面来说,我赞赏你试图从头开始实施逻辑回归。这是一次很棒的学习经历。你正在做很多事情(比如缩放输入X)。你应该尝试的是尽可能多地使用numpy功能来代码化你的代码。你的代码很慢。你可以把它变得非常快。如果它可以帮助你,我从头开始编写完整的逻辑回归实现,可以从github(https://github.com/michelucci/Logistic-Regression-Explained/blob/master/MNIST%20with%20Logistic%20Regression%20from%20scratch.ipynb)获得。也许它可以帮助你。
祝你好运,一切顺利,翁贝托