在运行时动态选择variable_scope

时间:2018-02-01 12:23:28

标签: variables if-statement tensorflow deep-learning

我想用一些张量的值来改变variable_scope。举个简单的例子,我定义了一个非常简单的代码:

import tensorflow as tf

def calculate_variable(scope):
    with tf.variable_scope(scope or type(self).__name__, reuse=tf.AUTO_REUSE):
        w = tf.get_variable('ww', shape=[5], initializer=tf.truncated_normal_initializer(mean=0.0, stddev=0.1))
        return w

w = calculate_variable('in_first')
w1 = calculate_variable('in_second')

功能很简单。它只返回在某个变量范围内定义的值。现在,'w'和'w1'会有不同的值。

我想要做的是通过张量的某些条件选择此变量范围。假设我有两个张量x,y,如果它们的值相同,我想从上面的函数获得某些变量范围的值。

x = tf.constant(3)
y = tf.constant(3)
condi = tf.cond(tf.equal(x, y), lambda: 'in_first', lambda: 'in_second')

w_cond = calculate_variable(condi)

我尝试了很多其他方法并搜索了互联网。但是,每当我想以与此示例类似的方式从张量条件确定variable_scope时,它就会显示错误。

TypeError: Using a `tf.Tensor` as a Python `bool` is not allowed. Use `if t is not None:` instead of `if t:` to test if a tensor is defined, and use TensorFlow ops such as tf.cond to execute subgraphs conditioned on the value of a tensor.

有没有好的解决方法?

1 个答案:

答案 0 :(得分:0)

你说的方式,这是不可能的。 variable_scope类显式检查name_or_scope参数是字符串还是VariableScope实例:

  ...
  if not isinstance(self._name_or_scope,
                    (VariableScope,) + six.string_types):
    raise TypeError("VariableScope: name_or_scope must be a string or "
                    "VariableScope.")

它不接受Tensor。这是合理的,因为变量范围是图形定义的一部分,无法动态定义变量。

最支持的表达式是:

x = tf.constant(3)
y = tf.constant(3)
w_cond = tf.cond(tf.equal(x, y), 
                 lambda: calculate_variable('in_first'), 
                 lambda: calculate_variable('in_second'))

...您可以在运行时选择两个变量中的任何一个。