Tensorflow:有条件地添加变量范围

时间:2017-01-17 11:25:17

标签: python scope tensorflow

我想在张量流中有条件地改变变量范围。

例如,如果scope是字符串或None

if scope is None:

        a = tf.get_Variable(....)
        b = tf.get_Variable(....)
else:
    with tf.variable_scope(scope):

        a = tf.get_Variable(....)
        b = tf.get_Variable(....)

但我不想让a= ...b= ...部分写成双倍。我只想让if ... else ...确定范围,然后从那里做其他所有事情。

关于我如何做到这一点的任何想法?

2 个答案:

答案 0 :(得分:2)

感谢@keveman让我走上正轨。虽然我无法使他的答案工作,但他让我走上了正确的轨道:我需要的是一个空洞的范围,因此以下工作:

class empty_scope():
     def __init__(self):
         pass
     def __enter__(self):
         pass
     def __exit__(self, type, value, traceback):
         pass

def cond_scope(scope):
    return empty_scope() if scope is None else tf.variable_scope(scope)

之后我可以这样做:

with cond_scope(scope):

    a = tf.get_Variable(....)
    b = tf.get_Variable(....)

有关python中with的更多信息,请参阅: The Python "with" Statement by Example

答案 1 :(得分:1)

这不是特定于TensorFlow,而是一般的Python语言问题,仅仅是FYI。无论如何,您可以使用包装器上下文管理器实现您想要的操作,如下所示:

class cond_scope(object):
  def __init__(self, condition, contextmanager):
    self.condition = condition
    self.contextmanager = contextmanager
  def __enter__(self):
    if self.condition:
      return self.contextmanager.__enter__()
  def __exit__(self, *args):
    if self.condition:
      return self.contextmanager.__exit__(*args)

with cond_scope(scope is not None, scope):
  a = tf.get_variable(....)
  b = tf.get_variable(....)

编辑:修正了代码。