我从简单实现单变量线性梯度下降开始,但不知道将其扩展到多变量随机梯度下降算法?
单变量线性回归
$config = Array(
'protocol' => 'smtp',
'smtp_host' => 'ssl://smtp.googlemail.com',
'smtp_port' => 465,
'smtp_user' => 'saurabh@gmail.com',
'smtp_pass' => 'my password',
'mailtype' => 'html',
'charset' => 'iso-8859-1'
);
$this->load->library('email', $config);
$this->email->set_newline("\r\n");
$result = $this->email->send();
答案 0 :(得分:8)
你的问题中有两个部分:
要获得更高的尺寸设置,您可以定义线性问题y = <x, w>
。然后,您只需要更改变量W
的维度以匹配w
的维度,并将乘法W*x_data
替换为标量积tf.matmul(x_data, W)
,并且您的代码应该运行很好。
要将学习方法更改为随机梯度下降,您需要使用tf.placeholder
抽象出成本函数的输入。
定义X
和y_
以在每个步骤保留输入后,您可以构建相同的成本函数。然后,您需要通过提供适当的小批量数据来调用您的步骤。
以下是如何实施此类行为的示例,它应显示W
快速收敛到w
。
import tensorflow as tf
import numpy as np
# Define dimensions
d = 10 # Size of the parameter space
N = 1000 # Number of data sample
# create random data
w = .5*np.ones(d)
x_data = np.random.random((N, d)).astype(np.float32)
y_data = x_data.dot(w).reshape((-1, 1))
# Define placeholders to feed mini_batches
X = tf.placeholder(tf.float32, shape=[None, d], name='X')
y_ = tf.placeholder(tf.float32, shape=[None, 1], name='y')
# Find values for W that compute y_data = <x, W>
W = tf.Variable(tf.random_uniform([d, 1], -1.0, 1.0))
y = tf.matmul(X, W, name='y_pred')
# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y_ - y))
optimizer = tf.train.GradientDescentOptimizer(0.01)
train = optimizer.minimize(loss)
# Before starting, initialize the variables
init = tf.initialize_all_variables()
# Launch the graph.
sess = tf.Session()
sess.run(init)
# Fit the line.
mini_batch_size = 100
n_batch = N // mini_batch_size + (N % mini_batch_size != 0)
for step in range(2001):
i_batch = (step % n_batch)*mini_batch_size
batch = x_data[i_batch:i_batch+mini_batch_size], y_data[i_batch:i_batch+mini_batch_size]
sess.run(train, feed_dict={X: batch[0], y_: batch[1]})
if step % 200 == 0:
print(step, sess.run(W))
两个旁注:
下面的实现称为小批量梯度下降,因为每个步骤都使用大小为mini_batch_size
的数据子集计算梯度。这是随机梯度下降的变体,通常用于稳定每个步骤的梯度估计。可以通过设置mini_batch_size = 1
。
数据集可以在每个时期进行随机播放,以使实现更接近理论考虑。最近的一些工作还考虑只使用一次通过数据集,因为它可以防止过度拟合。有关更详细和详细的解释,您可以看到Bottou12。这可以根据您的问题设置和您正在寻找的统计属性轻松更改。