在循环内修改Tensorflow变量

时间:2018-07-19 09:43:00

标签: python-3.x tensorflow while-loop

我想在while循环内修改变量的某些索引。 基本上将下面的python代码转换为Tensorflow:

import numpy
tf_variable=numpy.zeros(10,numpy.int32)
for i in range (10):
    tf_variable[i]=i
tf_variable

Tensorflow代码如下所示:除了它给出错误

import tensorflow as tf
var=tf.get_variable('var',initializer=tf.zeros([10],tf.int32),trainable=False)
itr=tf.constant(0)
sess=tf.Session()
sess.run(tf.global_variables_initializer()) #initializing variables


print('itr=',sess.run(itr))
def w_c(itr,var):
    return(tf.less(itr,10))
def w_b(itr,var):
    var=tf.assign(var[1],9) #lets say i want to modify index 1 of variable var
    itr=tf.add(itr,1)
    return [itr,var] #these tensors when returning actually get called


OP=tf.while_loop(w_c,w_b,[itr,var],parallel_iterations=1,back_prop=False)
print(sess.run(OP))

谢谢

2 个答案:

答案 0 :(得分:0)

这是一件非常独特的事情,我敢肯定,如果您进一步详细说明问题,可以为您提供更好的帮助,但是如果您打算在tf.variable中更改变量,这就是我的建议

tf_Variable=tf.random_normal([1,10])
array=tf.Session().run(tf_Variable)
print(array)
  

array([[1.8884579,-1.4278126,-1.5084593, 2.2028043 ,0.10910247,           -1.6836789,0.41359457,2.0960712,0.5169063,-0.66555417]],         dtype = float32)

array[0][3]=2
print(array)
  

array([[1.8884579,-1.4278126,-1.5084593, 2。,0.10910247,           -1.6836789,0.41359457,2.0960712,0.5169063,-0.66555417]],         dtype = float32)

如果您喜欢As is explained here

,可以再次将其输入到tf变量中

答案 1 :(得分:0)

在CPU上进行“绕行”并不总是可行的(失去梯度)。这是一种在TensorFlow中实现numpy示例的可能性(受this post和我对this other post的回答启发)

import tensorflow as tf

tf_variable = tf.Variable(tf.ones([10]))

def body(i, v):
    index = i
    new_value = tf.to_float(i)
    delta_value = new_value - v[index:index+1]
    delta = tf.SparseTensor([[index]], delta_value, (10,))
    v_updated = v + tf.sparse_tensor_to_dense(delta)
    return tf.add(i, 1), v_updated


_, updated = tf.while_loop(lambda i, _: tf.less(i, 10), body, [0, tf_variable])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf_variable))
    print(sess.run(updated))

此打印

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[0. 1. 2. 3. 4. 5. 6. 7. 8. 9.]