如何在tensorlfow中定义我自己的函数中的tf.Variable?

时间:2017-05-15 11:43:18

标签: tensorflow

import tensorflow as tf
import numpy as np


def myfunction():
    return_variable = tf.Variable(initial_value=0.0, dtype=tf.float32)
    return return_variable

a = np.random.randint(1, 5, size=(3, 2, 2))
a_variable = tf.Variable(a, tf.float32)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    print('a_variable')
    print(sess.run(a_variable))
    print('myfunction')
    print(sess.run(myfunction()))

我想在我的函数中初始化一个变量。当我运行我的代码时,我收到错误“尝试使用未初始化的值Variable_1”。我想知道为什么函数中定义的变量不能被'tf.initialize_all_variables()'初始化。我怎么能在我的函数中使用变量?

1 个答案:

答案 0 :(得分:0)

变量仅在您调用myfunction()时创建,如果已在运行tf.initialize_all_variables()时已创建,则进行初始化。所以你只需要在初始化所有变量之前调用myfunction()。通常,如果需要,可以使用myfunction()构建整个图表,然后在训练之前初始化,这样就不会有任何问题...请注意myfunction()创建一个新的(相同的) )每次调用它时,它都不返回指向同一个底层变量的指针。

请参阅以下代码:

import tensorflow as tf
import numpy as np
def myfunction():
    return_variable = tf.Variable(initial_value=0.0, dtype=tf.float32)
    return return_variable
a = np.random.randint(1, 5, size=(3, 2, 2))
a_variable = tf.Variable(a, tf.float32)
with tf.Session() as sess:
    x = myfunction()   # Create variable x, not initialized yet
    sess.run(tf.initialize_all_variables())  # Initialize all variables that already exist in your graph, including x, but y does not exist yet
    print('a_variable')
    print(sess.run(a_variable))
    print('x')  
    print(sess.run(x))  # Works fine: x is initialized
    print('y')
    y = myfunction()    # Create variable y, not initialized
    print(sess.run(y))  # This crashes, because y is not initialized