在tensorflow中更新变量值

时间:2016-12-01 20:23:56

标签: tensorflow

我有一个关于通过tensorflow python api更新张量值的基本问题。

考虑代码段:

x = tf.placeholder(shape=(None,10), ... )
y = tf.placeholder(shape=(None,), ... )
W = tf.Variable( randn(10,10), dtype=tf.float32 )
yhat = tf.matmul(x, W)

现在让我们假设我想实现某种迭代更新W值的算法(例如一些优化算法)。这将涉及以下步骤:

for i in range(max_its):
     resid = y_hat - y
     W = f(W , resid) # some update 

这里的问题是LHS上的W是一个新的张量,而不是W中使用的yhat = tf.matmul(x, W)!也就是说,创建了一个新变量,并且我的“模型”中使用的W值不会更新。

现在解决这个问题的方法是

 for i in range(max_its):
     resid = y_hat - y
     W = f(W , resid) # some update 
     yhat = tf.matmul( x, W)

导致为我的循环的每次迭代创建一个新的“模型”!

有没有更好的方法来实现这个(在python中)而不为循环的每次迭代创建一大堆新模型 - 而是更新原始张量W“就地”可以这么说?

3 个答案:

答案 0 :(得分:7)

变量有一个assign方法。尝试:W.assign(f(W,resid))

答案 1 :(得分:2)

@ aarbelle的简洁回答是正确的,如果有人需要更多信息,我会稍微扩展一下。下面的最后两行用于更新W。

x = tf.placeholder(shape=(None,10), ... )
y = tf.placeholder(shape=(None,), ... )
W = tf.Variable(randn(10,10), dtype=tf.float32 )
yhat = tf.matmul(x, W)

...

for i in range(max_its):
    resid = y_hat - y
    update = W.assign(f(W , resid)) # do not forget to initialize tf variables. 
    # "update" above is just a tf op, you need to run the op to update W.
    sess.run(update)

答案 2 :(得分:0)

精确地,答案应该是sess.run(W.assign(f(W,resid)))。然后使用sess.run(W)显示更改。