最小化Tensorflow中的功能

时间:2018-11-30 00:00:59

标签: python tensorflow gradients

如何使用tf.gradients获得函数的渐变?当我使用GradientDescentOptimizer.minimize()时,以下内容正在工作,tf.gradients似乎在x ^ 2 + 2的导数为2x时求值为1

我想念什么?

x = tf.Variable(1.0, trainable=True)
y = x**2 + 2

grad = tf.gradients(y, x)
#grad = tf.train.GradientDescentOptimizer(0.1).minimize(y)

init = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init)
    grad_value = sess.run(grad)
    print(grad_value)

1 个答案:

答案 0 :(得分:3)

如果我正确理解了您的问题,则希望找到使x最小的x^2 + 2的值。

为此,您需要重复调​​用GradientDescentOptimizer,直到x收敛到使函数最小化的值为止。这是因为梯度下降是一种迭代技术。

此外,在张量流中,minimize的方法GradientDescentOptimizer会计算梯度,然后将其应用于相关变量(在您的情况下为x)。因此,代码应如下所示(注意,我注释了grad变量,除非您要查看渐变值,否则该变量不是必需的):

x = tf.Variable(1.0, trainable=True)
y = x**2 + 2

# grad = tf.gradients(y, x)
grad_op = tf.train.GradientDescentOptimizer(0.2).minimize(y)

init = tf.global_variables_initializer()

n_iterations = 10
with tf.Session() as sess:
    sess.run(init)
    for i in range(n_iterations):
        _, new_x = sess.run([grad_op, x])
        print('Iteration:', i,', x:', new_x)

您将得到:

Iteration: 0 , x: 1.0
Iteration: 1 , x: 0.6
Iteration: 2 , x: 0.36
Iteration: 3 , x: 0.216
Iteration: 4 , x: 0.07776
Iteration: 5 , x: 0.07776
Iteration: 6 , x: 0.046656
Iteration: 7 , x: 0.01679616
Iteration: 8 , x: 0.010077696
Iteration: 9 , x: 0.010077696

在收敛为0的真实答案时会看到。

如果将GradientDescentOptimizer的学习率从0.2提高到0.4,它将更快地收敛到0。

编辑

好的,根据我对问题的新理解,您无法执行x = x - alpha * gradient,因为这是python操作,它只是替换了对象x,因此您无法进行x.assign。您需要告诉tensorflow将op添加到图中,这可以使用x = tf.Variable(1.0, trainable=True) y = x**2 + 2 grad = tf.gradients(y, x) # grad_op = tf.train.GradientDescentOptimizer(0.5).minimize(y) update_op = x.assign(x - 0.2*grad[0]) init = tf.global_variables_initializer() with tf.Session() as sess: sess.run(init) for i in range(10): new_x = sess.run([update_op, x]) print('Iteration:', i,', x:', new_x) 完成。它看起来像:

GradientDescentOptimizer

,我们得到与本地Iteration: 0 , x: 1.0 Iteration: 1 , x: 0.6 Iteration: 2 , x: 0.36 Iteration: 3 , x: 0.1296 Iteration: 4 , x: 0.1296 Iteration: 5 , x: 0.077759996 Iteration: 6 , x: 0.046655998 Iteration: 7 , x: 0.027993599 Iteration: 8 , x: 0.01679616 Iteration: 9 , x: 0.010077696 相同的答案:

{{1}}