张量流中的变量更新问题

时间:2017-10-24 22:20:10

标签: tensorflow

我的问题是为什么在下面的简单代码中,某些变量(例如w_3)的值不会更新,但对于其他变量,它们会更新。

import tensorflow as tf
import numpy as np
x=([1, 2, 3])
x= np.array(x)

sess= tf.InteractiveSession()

input_data= tf.placeholder(dtype= 'float32', shape= (None))

w_1= tf.Variable(tf.truncated_normal([1], stddev= 0.01), trainable= True, name='w_1')

w_2= tf.Variable(tf.truncated_normal([1], stddev= 0.01), trainable= True, name='w_2')

w_3= tf.Variable(tf.truncated_normal([1], stddev= 0.01), trainable= True, name='w_3')

loss= tf.pow(w_1, 2)- input_data+ tf.pow(w_2, 2)+ tf.pow(w_1, 2)

optimizer = tf.train.GradientDescentOptimizer(learning_rate= 0.01)

train_op = optimizer.minimize(loss)

init= tf.global_variables_initializer()

sess.run(init)

for j in range(0,4):

    for i in range(0,3):
        sess.run(train_op, feed_dict={input_data: x[i]})
        print('w1:',sess.run(w_1, feed_dict={input_data: x[i]}))
        print('w2:',sess.run(w_2, feed_dict={input_data: x[i]}))
        print('w3:',sess.run(w_3, feed_dict={input_data: x[i]}))

1 个答案:

答案 0 :(得分:1)

这是预期的:您的w_3变量不参与损失计算。因此,渐变不依赖于它,并且w_3变量不会更新!

也许你打算使用w_3并制作一个简单但典型的拼写错误!