我试图在网络中构建多个分支。所以我在代码中使用了tf.case。但我发现tf.case总是构建最后一个可调用函数两次,这将导致变量错误:“变量XXX已经存在”(我通过slim创建变量,变量范围“case / If_x”将不存在,这是为什么我会得到错误)。这是一个带有输出的测试程序。
import numpy as np
import tensorflow as tf
slim = tf.contrib.slim
def fn1(X, Y):
with tf.variable_scope("fn1", reuse=False):
w = tf.Variable(1.0, name="w")
#w = slim.variable(name="w", shape=())
return X*w, Y*w
def fn2(X, Y):
with tf.variable_scope("fn2", reuse=False):
w = tf.Variable(2.0, name="w")
#w = slim.variable(name="w", shape=())
return X*w, Y*w
def fn3(X, Y):
with tf.variable_scope("fn3", reuse=False):
w = tf.Variable(3.0, name="w")
#w = slim.variable(name="w", shape=())
return X*w, Y*w
class Test:
def __init__(self):
self.Z = tf.placeholder(dtype=tf.int32, shape=())
self.X = tf.Variable(1.0, name="X")
self.Y = tf.Variable(2.0, name="Y")
def build(self):
self.result = tf.case(
pred_fn_pairs=[
(tf.equal(self.Z, 10), lambda : fn3(self.X, self.Y)),
(tf.equal(self.Z, 20), lambda : fn2(self.X, self.Y)),
(tf.equal(self.Z, 30), lambda : fn1(self.X, self.Y))],
exclusive=False)
test = Test()
test.build()
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
tvars = tf.trainable_variables()
tvars_vals = sess.run(tvars)
for var, val in zip(tvars, tvars_vals):
print(var.name)
aa = sess.run(test.result, feed_dict={test.Z:20})
print aa
输出结果为:
X:0
Y:0
case/If_0/fn1/w:0
case/If_0/fn1_1/w:0
case/If_1/fn2/w:0
case/If_2/fn3/w:0
(2.0, 4.0)