在函数中更改tensorflow变量

时间:2017-09-01 21:23:49

标签: python tensorflow

运行下面的代码时,我得到了奇怪的结果(我只是想在函数内部访问计算图中已有的变量):

import tensorflow as tf


def mul(x):
    """Return x * x + p, where p is fetched with tf.get_variable('vara').

    NOTE(review): the output recorded further down (23.9039 where 26.0 was
    expected) shows that p is not the module-level variable `v`.  That `v` is
    built with the tf.Variable constructor, and tf.get_variable only looks up
    variables that were themselves created through tf.get_variable, so this
    call presumably creates a second, separately initialized variable.
    """
    p = tf.get_variable('vara', shape=())
    return x * x + p

# Scalar float input, fed at run time through feed_dict.
x = tf.placeholder(dtype=tf.float32, shape=())
# Variable built directly with the tf.Variable constructor, initial value 1.0.
v = tf.Variable(1.0, name='vara')

# Build the x * x + p op once; the same op is evaluated twice below.
out = mul(x)


with tf.Session() as sess:

    # Run the initializer for all global variables in the graph.
    sess.run(tf.global_variables_initializer())


    # Each print shows the expected value first, then the computed value.
    print(26.0, sess.run(out, feed_dict={x: 5.0}))
    # Recorded output: 26.0 23.9039 -- does not match the expected 26.0.

    print(1.0, sess.run(v))
    # Recorded output: 1.0 1.0

    # v += 2
    sess.run(v.assign_add(2.0))        
    print(3.0, sess.run(v))
    # Recorded output: 3.0 3.0


    print(28.0, sess.run(out, feed_dict={x: 5.0}))
    # Recorded output: 28.0 25.5912 -- does not match the expected 28.0.

1 个答案:

答案 0 :(得分:0)

我自己想明白了:用 tf.variable_scope 和 tf.get_variable 创建变量,之后在同一作用域下以 reuse=True 再次获取它:

import tensorflow as tf

# create a new var using scope
def new_var(scope_name, var, shape=None):
    """Create a variable initialized to 1.0 inside a named variable scope.

    Args:
        scope_name: Name of the tf.variable_scope that owns the variable.
        var: Name of the variable inside that scope.
        shape: Shape of the variable. None (the default) creates a scalar.

    Returns:
        The variable returned by tf.get_variable.
    """
    # tf.get_variable rejects an explicit shape when the initializer is a
    # tensor, so use an initializer object and always pass a concrete shape.
    if shape is None:
        shape = ()
    with tf.variable_scope(scope_name) as varscope:
        v = tf.get_variable(
            var, shape=shape, initializer=tf.constant_initializer(1.0))
        # Switch this scope object to reuse mode before leaving it.
        varscope.reuse_variables()
        return v

# look an existing var up again from any place in the graph
def get_var(scope_name, var, shape=None):
    """Fetch a variable by re-entering its scope with reuse enabled.

    Args:
        scope_name: Name of the tf.variable_scope to open with reuse=True.
        var: Name of the variable inside that scope.
        shape: Optional shape forwarded to tf.get_variable.

    Returns:
        Whatever tf.get_variable returns for that name inside the scope.
    """
    with tf.variable_scope(scope_name, reuse=True):
        return tf.get_variable(var, shape=shape)

# example function that reads a variable defined elsewhere in the graph
def mul(x):
    """Return x * x plus the variable 'v' looked up in scope 'foo'."""
    offset = get_var('foo', 'v')
    squared = x * x
    return squared + offset

# build the graph: scalar input, the shared variable, and the output op
x = tf.placeholder(dtype=tf.float32, shape=(), name='x')
v = new_var('foo', 'v')

out = mul(x)

with tf.Session() as sess:
    # run the initializer for all global variables before reading them
    init_op = tf.global_variables_initializer()
    sess.run(init_op)

    # each print shows the expected value first, then the computed value
    first_result = sess.run(out, feed_dict={x: 5.0})
    print(26.0, first_result)
    # recorded output: 26.0 26.0

    initial_v = sess.run(v)
    print(1.0, initial_v)
    # recorded output: 1.0 1.0

    # v += 2
    increment_op = v.assign_add(2.0)
    sess.run(increment_op)

    updated_v = sess.run(v)
    print(3.0, updated_v)
    # recorded output: 3.0 3.0

    second_result = sess.run(out, feed_dict={x: 5.0})
    print(28.0, second_result)
    # recorded output: 28.0 28.0

    # write the session graph to the log directory, then close the writer
    writer = tf.summary.FileWriter("/Users/waf/Desktop/logs/", sess.graph)
    writer.close()