Plotting gradient descent in 3D - contour plots

Date: 2020-02-14 00:26:03

Tags: matplotlib multidimensional-array linear-regression gradient

I have generated 3 parameters and the cost function: I have the list of theta values and the list of 100 cost values from 100 iterations. I want to plot the last two parameters against the cost in 3D, to visualize the contour plot and the level sets of the bowl-shaped cost function (see the plotting sketch after the output below).

Housing dataset with 3 features [1's, bedrooms, Sq.ft] (https://drive.google.com/open?id=13v8ijuzbj8Z-taGK_4D37P008ML-DIZk), used to predict price; X has shape (100000, 3) and y has shape (100000,). The aim is to study the bowl-shaped cost surface in 3D and observe how gradient descent converges.
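For context, the design matrix described above can be assembled roughly as follows. This is a minimal sketch: the file name housing.csv and the column names bedrooms, sqft and price are assumptions, since the exact headers of the linked file are not shown here.

import numpy as np
import pandas as pd

# Hypothetical file and column names - adjust to the actual CSV from the link above.
data = pd.read_csv("housing.csv")

# Design matrix: a column of ones (intercept), bedrooms, square footage.
X = np.column_stack([
    np.ones(len(data)),           # column of 1's for the intercept theta_0
    data["bedrooms"].to_numpy(),  # feature for theta_1
    data["sqft"].to_numpy(),      # feature for theta_2
])
y = data["price"].to_numpy()      # target: house price, shape (100000,)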

Reference: Gradient descent impementation python - contour lines

import numpy as np

def compute_cost(X, y, theta):
    # Mean squared error cost for linear regression (halved).
    return np.sum(np.square(np.matmul(X, theta) - y)) / (2 * len(y))

def gradient_descent_multi(X, y, theta, alpha, iterations):
    m = len(X)
    j_history = np.zeros(iterations)
    theta_1_hist = []
    theta_2_hist = []

    for i in range(iterations):
        # Batch gradient of the cost with respect to theta.
        gradient = (1 / m) * np.matmul(X.T, np.matmul(X, theta) - y)
        theta = theta - alpha * gradient

        # Record the cost and the last two parameters at every iteration.
        j_history[i] = compute_cost(X, y, theta)
        theta_1_hist.append(theta[1])
        theta_2_hist.append(theta[2])

    return theta, j_history, theta_1_hist, theta_2_hist

# theta must match the number of columns in X (here 3: intercept, bedrooms, Sq.ft).
theta = np.zeros(X.shape[1])
alpha = 0.1
iterations = 100

# Computing the gradient descent
theta_result, J_history, theta_1_hist, theta_2_hist = gradient_descent_multi(X, y, theta, alpha, iterations)

Theta 1:
[15.651431183495157,
 28.502297542920118,
 39.0665487784193,
 ...
 105.78644212297141,
 105.882701389551,
 105.97741737336399]
Theta 2:
[14.713094556818124,
 26.640668175454184,
 36.29642936488919,
 ....
 59.1710519900493,
 59.07633606136845]
Cost array: 
array([185814.55027215, 149566.02825652, 120605.70700938,  97414.66187874,
            78807.39414333,  63853.50250138,  51819.24085843,  42123.5122655 ,
            34304.44290442,  27993.78459818,  22897.16477958,  18778.74417703,
           ....
             1257.38095357,   1257.13475353,   1256.89643143,   1256.66572779,
             1256.44239308,   1256.22618706,   1256.01687827,   1255.81424349,
             1255.61806734,   1255.42814185,   1255.24426618,   1255.06624625])
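The plotting step the question asks about is not shown above; below is a minimal sketch of one way to do it. It evaluates the cost on a grid over theta_1 and theta_2 while holding the intercept theta_0 fixed at its converged value (an assumption, since the cost really depends on all three parameters), then draws the bowl-shaped surface, its contour lines (level sets), and the recorded descent path theta_1_hist / theta_2_hist on top. The grid range of +/- 20 around the recorded path is an arbitrary choice for illustration.

import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # registers the 3d projection

# Grid over the last two parameters, centred on the recorded descent path.
t1 = np.linspace(min(theta_1_hist) - 20, max(theta_1_hist) + 20, 100)
t2 = np.linspace(min(theta_2_hist) - 20, max(theta_2_hist) + 20, 100)
T1, T2 = np.meshgrid(t1, t2)

# Cost at each grid point, with theta_0 held fixed at its converged value.
# (A plain loop for clarity; it re-evaluates the full dataset at every point, so it is slow.)
Z = np.zeros_like(T1)
for i in range(T1.shape[0]):
    for j in range(T1.shape[1]):
        Z[i, j] = compute_cost(X, y, np.array([theta_result[0], T1[i, j], T2[i, j]]))

fig = plt.figure(figsize=(12, 5))

# 3D bowl-shaped cost surface with the gradient descent path overlaid.
ax = fig.add_subplot(1, 2, 1, projection='3d')
ax.plot_surface(T1, T2, Z, cmap='viridis', alpha=0.6)
ax.plot(theta_1_hist, theta_2_hist, J_history, 'r.-', label='gradient descent path')
ax.set_xlabel('theta_1')
ax.set_ylabel('theta_2')
ax.set_zlabel('cost')
ax.legend()

# 2D contour plot (level sets) of the same surface, with the path on top.
ax2 = fig.add_subplot(1, 2, 2)
ax2.contour(T1, T2, Z, levels=30, cmap='viridis')
ax2.plot(theta_1_hist, theta_2_hist, 'r.-')
ax2.set_xlabel('theta_1')
ax2.set_ylabel('theta_2')

plt.tight_layout()
plt.show()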

0 Answers