Tensorflow tf.trainable_variables(scope =“”)函数

时间:2018-06-22 12:03:19

标签: python tensorflow jupyter-notebook

我尝试通过以下替换来简化我的张量流代码:

f_vars = tf.trainable_variables("foo")

代替以前的语法:

t_vars = tf.trainable_variables()
f_vars = [var for var in t_vars if var.name.startswith('foo')]

之前,我尝试使用:p来更新tensorflow到最新版本:

!pip install --upgrade tensorflow

在jupyter笔记本中。

检查版本,它返回了

TensorFlow Version: 1.8.0

当我尝试运行它时,tensorflow返回以下错误。

TypeError: trainable_variables() takes 0 positional arguments but 1 was given

这是怎么了?在tensorflow文档中,您可以为tf.trainable_variables()命令的作用域插入参数。 -> https://www.tensorflow.org/api_docs/python/tf/trainable_variables

1 个答案:

答案 0 :(得分:1)

函数调用tf.trainable_scope('foo')需要定义一个名为'foo'的变量范围。

例如:

a = tf.Variable(1, name='a')
with tf.variable_scope('foo'):
  b = tf.Variable(1, name='b')

要获取可训练的变量,请致电:

tf.trainable_variables()
# return variables named 'a' and 'foo/b'

tf.trainable_variables('foo')
# returns variables named 'foo/b'