为什么tf.case会构建两次可调用函数?

时间:2018-06-05 20:31:19

标签: tensorflow

我试图在网络中构建多个分支。所以我在代码中使用了tf.case。但我发现tf.case总是构建最后一个可调用函数两次,这将导致变量错误:“变量XXX已经存在”(我通过slim创建变量,变量范围“case / If_x”将不存在,这是为什么我会得到错误)。这是一个带有输出的测试程序。

import numpy as np
import tensorflow as tf

slim = tf.contrib.slim

def fn1(X, Y):
    with tf.variable_scope("fn1", reuse=False):
       w = tf.Variable(1.0, name="w")
       #w = slim.variable(name="w", shape=())
    return X*w, Y*w

def fn2(X, Y):
    with tf.variable_scope("fn2", reuse=False):
       w = tf.Variable(2.0, name="w")
       #w = slim.variable(name="w", shape=())
    return X*w, Y*w

def fn3(X, Y):
    with tf.variable_scope("fn3", reuse=False):
       w = tf.Variable(3.0, name="w")
       #w = slim.variable(name="w", shape=())
    return X*w, Y*w

class Test:

    def __init__(self):
        self.Z = tf.placeholder(dtype=tf.int32, shape=())
        self.X = tf.Variable(1.0, name="X")
        self.Y = tf.Variable(2.0, name="Y")

    def build(self):
        self.result = tf.case(
            pred_fn_pairs=[
                    (tf.equal(self.Z, 10), lambda : fn3(self.X, self.Y)),
                    (tf.equal(self.Z, 20), lambda : fn2(self.X, self.Y)),
                    (tf.equal(self.Z, 30), lambda : fn1(self.X, self.Y))],
                    exclusive=False)

test = Test()
test.build()

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())

    tvars = tf.trainable_variables()
    tvars_vals = sess.run(tvars)

    for var, val in zip(tvars, tvars_vals):
        print(var.name) 

    aa = sess.run(test.result, feed_dict={test.Z:20})
    print aa

输出结果为:

X:0
Y:0
case/If_0/fn1/w:0
case/If_0/fn1_1/w:0
case/If_1/fn2/w:0
case/If_2/fn3/w:0
(2.0, 4.0)

0 个答案:

没有答案