我正在尝试从头开始实现的最佳神经网络优化算法[1]中的预测的可视化比较。
亚当的损失是:0.6931
亚当的结果是否正确?
这就是我得到的图形:
np.random.seed(42)
w = np.array([0, 0, 0, 0, 0, 1])
eta = 0.05 # learning rate
alpha = 0.9 # momentum
nu = np.zeros_like(w)
n_iter = 100
batch_size = 4
loss = np.zeros(n_iter)
plt.figure(figsize=(12, 5))
for i in range(n_iter):
ind = np.random.choice(X_expanded.shape[0], batch_size)
loss[i] = compute_loss(X_expanded, y, w)
if i % 10 == 0:
visualize(X_expanded[ind, :], y[ind], w, loss)
grad = compute_grad(X_expanded, y, w)
nu = alpha * nu + eta * grad
w = w - nu
visualize(X, y, w, loss)
plt.clf()
np.random.seed(42)
w = np.array([0, 0, 0, 0, 0, 1.])
eta = 0.1 # learning rate
alpha = 0.9 # moving average of gradient norm squared
g2 = np.zeros_like(w)
eps = 1e-8
n_iter = 100
batch_size = 4
loss = np.zeros(n_iter)
plt.figure(figsize=(12,5))
for i in range(n_iter):
ind = np.random.choice(X_expanded.shape[0], batch_size)
loss[i] = compute_loss(X_expanded, y, w)
if i % 10 == 0:
visualize(X_expanded[ind, :], y[ind], w, loss)
grad = compute_grad(X_expanded, y, w)
grad2 = grad ** 2
g2 = alpha * g2 + (1-alpha) * grad2
w = w - eta * grad / np.sqrt(g2 + eps)
visualize(X, y, w, loss)
plt.clf()
np.random.seed(42)
w = np.array([0, 0, 0, 0, 0, 1.])
eta = 0.01 # learning rate
beta1 = 0.9 # moving average of gradient norm
beta2 = 0.999 # moving average of gradient norm squared
m = np.zeros_like(w) # Initial 1st moment estimates
nu = np.zeros_like(w) # Initial 2nd moment estimates
eps = 1e-8 # A small constant for numerical stability
n_iter = 100
batch_size = 4
loss = np.zeros(n_iter)
plt.figure(figsize=(12,5))
for i in range(n_iter):
ind = np.random.choice(X_expanded.shape[0], batch_size)
loss[i] = compute_loss(X_expanded, y, w)
if i % 10 == 0:
visualize(X_expanded[ind, :], y[ind], w, loss)
grad = compute_grad(X_expanded, y, w)
grad2 = grad ** 2
m = ((beta1 * m) + ((1 - beta1) * grad)) / (1 - beta1)
nu = ((beta2 * nu) + ((1 - beta2) * grad2)) / (1 - beta2)
w = (w - eta * m) / (np.sqrt(nu) + eps)
visualize(X, y, w, loss)
plt.clf()
我希望亚当的成本降低。我的意思是比RMSprop(0.2075)少。