在Tensorflow中使用tf.while_loop更新变量

时间:2018-09-08 14:57:08

标签: python tensorflow

我想更新Tensorflow中的变量,因此我使用tf.while_loop像这样:

RuntimeError: Failed to init API, possibly an invalid tessdata path: 
C:\Users\boodn\Anaconda3\envs\

这是我正在尝试做的一个例子。发生的错误是a = tf.Variable([0, 0, 0, 0, 0, 0] , dtype = np.int16) i = tf.constant(0) size = tf.size(a) def condition(i, size, a): return tf.less(i, size) def body(i, size, a): a = tf.scatter_update(a, i , i) return [tf.add(i, 1), size, a] r = tf.while_loop(condition, body, [i, size, a]) 。在Tensorflow中更新变量的合适方法是什么?

1 个答案:

答案 0 :(得分:1)

直到一个代码执行并执行,这才很明显。就像这样pattern

import tensorflow as tf


def cond(size, i):
    return tf.less(i,size)

def body(size, i):

    a = tf.get_variable("a",[6],dtype=tf.int32,initializer=tf.constant_initializer(0))
    a = tf.scatter_update(a,i,i)

    tf.get_variable_scope().reuse_variables() # Reuse variables
    with tf.control_dependencies([a]):
        return (size, i+1)

with tf.Session() as sess:

    i = tf.constant(0)
    size = tf.constant(6)
    _,i = tf.while_loop(cond,
                    body,
                    [size, i])

    a = tf.get_variable("a",[6],dtype=tf.int32)

    init = tf.initialize_all_variables()
    sess.run(init)

    print(sess.run([a,i]))

输出为

  

[array([0,1,2,3,4,5]),6]

  1. tf.get_variable使用这些参数获取现有变量或创建一个新变量。
  2. tf.control_dependencies这是先发生的关系。在这种情况下,我了解到scatter_update发生在while递增并返回之前。没有这个,它不会更新。

注意:我不太了解错误的含义或原因。我也明白。