Counter in a TensorFlow graph

Time: 2017-08-31 22:57:24

Tags: python tensorflow

I want to create a counter in the graph that increments every time a function runs. This is what I am trying at the moment. Any ideas?

# init variable
tf_i = tf.Variable(1, name='v', dtype=tf.int32, expected_shape=())

# init assign op
tf_i_plus_one = tf.assign_add(tf_i, 1)
sess.run(tf.global_variables_initializer())

# simple test fx
def test(x):
    v = tf.get_variable('v', shape=())
    return x + v

# run test fx
z = tf.placeholder(tf.float32, shape=(), name='z')
out = test(z)

print(sess.run(out, feed_dict={z: 4}))
sess.run(tf_i_plus_one)

print(sess.run(out, feed_dict={z: 4}))

This currently gives me:

Attempting to use uninitialized value v_1

1 answer:

Answer 0 (score: 0)

If all you want is a counter for a function, try the following:

import tensorflow as tf  # TensorFlow 1.x graph-mode API

def wrap_with_counter(fn, counter):
    def wrapped_fn(*args, **kwargs):
        # control_dependencies forces the assign op to be run even if we don't use the result
        with tf.control_dependencies([tf.assign_add(counter, 1)]):
            return fn(*args, **kwargs)
    return wrapped_fn
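
To see what the control dependency does, here is a minimal sketch that applies the same wrapper to a trivial function instead of a layer. It assumes TensorFlow 1.x graph mode (imported through `tensorflow.compat.v1` so it should also run on a 2.x install); `double`, `calls`, `before` and `after` are made-up names for illustration. Note that the counter goes up each time the wrapped op is executed by `sess.run`, not when the Python function is called to build the graph.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # assumption: graph mode, as in the TF 1.x code above


def wrap_with_counter(fn, counter):
    def wrapped_fn(*args, **kwargs):
        # Ops created inside this block only run after the increment has run.
        with tf.control_dependencies([tf.assign_add(counter, 1)]):
            return fn(*args, **kwargs)
    return wrapped_fn


def double(t):
    # Hypothetical stand-in for a real layer; it creates a new op from t.
    return t * 2.0


calls = tf.get_variable(
    name='calls', shape=(), dtype=tf.int32,
    initializer=tf.zeros_initializer())
x = tf.placeholder(tf.float32, shape=(), name='x')
y = wrap_with_counter(double, calls)(x)  # building the graph does not count

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    before = sess.run(calls)
    sess.run(y, feed_dict={x: 1.0})
    sess.run(y, feed_dict={x: 1.0})
    after = sess.run(calls)
    print(before, after)  # 0 2
```

The dependency only attaches to ops created inside the `with` block, so a function that just returns an already existing tensor would not trigger the increment.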

Example usage of wrap_with_counter:

dense_counter = tf.get_variable(
    dtype=tf.int32, shape=(), name='dense_counter',
    initializer=tf.zeros_initializer())
wrapped_dense = wrap_with_counter(tf.layers.dense, dense_counter)


batch_size = 4
n_features = 8
# dummy input
x = tf.random_normal(
    shape=(batch_size, n_features), dtype=tf.float32, name='x')

out = wrapped_dense(x, 1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    c0 = sess.run(dense_counter)
    print(c0)  # 0
    sess.run(out)
    c1 = sess.run(dense_counter)
    print(c1)  # 1