我有一个这样的功能,可以构建一个网络。
def build_network(inputs):
# Some arbitrary set of variables and ops here. For example...
out = tf.contrib.layers.fully_connected(inputs, 123)
(...)
return out
然后我用它来建立这样的网络。
inputs = tf.placeholder(...)
outputs = build_network(inputs)
如果我想构建更多具有相同结构但独立变量的网络,我只需要在其他变量范围和其他输入下再次调用 build_network 。
我的问题是:如果 build_network 不再可用,但原始网络的输入和输出是什么,我怎么能这样做呢?换句话说:如何将输出中的整个子图一直克隆到输入到另一个变量范围内,并使用自己独立的变量集但结构相同?
我的理解是,一般来说tf.contrib.graph_editor和graph_editor.copy正是我做这些事情所需的工具。但是,我找不到任何好用的例子。有什么建议吗?
答案 0 :(得分:2)
回应自己,我发现了一种复制子图的方法。
from tensorflow.contrib import graph_editor as ge
# From the example above.
inputs = [tf.placeholder(...), ...]
outputs = build_network(inputs)
sgv = ge.make_view(ge.get_within_boundary_ops(
tf.get_default_graph(),
[t.op for t in outputs],
[t.op for t in inputs]))
# This could be any new inputs. In this example I build new identical placeholders.
new_inputs = {p: tf.placeholder(dtype=p.dtype, shape=p.shape) for p in inputs}
new_sgv, info = ge.copy_with_input_replacements(sgv, new_inputs, dst_scope='copy')
new_inputs = [info.transformed(t) for t in inputs]
new_outputs = [info.transformed(t) for t in outputs]
但是,现在我在尝试使用网络副本时遇到了一个新问题。副本中的新变量未初始化,尝试运行tf.global_variables_initializer()也无济于事。
原因是因为这些的tf.Variable从未构建过,所以它们不是GlobalKeys.GLOBAL_VARIABLES集合的一部分。我可以很容易地找到对应于这些变量的ops以及它们在原始和副本之间的映射,但是我无法从中构建一个tf.Variable。
我发现了一些hacky解决方法来进行初始化,但它只适用于集合中的变量。
init_ops = []
for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
if v.op in sgv.ops:
init_ops.append(info.transformed(v.initializer))
...
session.run([tf.global_variables_initializer()] + init_ops)
有没有更好的方法呢?理想情况下,允许为复制的vars创建tf.Variables以将其添加到全局变量集合。或者,如果不可能,至少可以获得初始化操作,而不必找到原始网络的tf.Variable对象。
答案 1 :(得分:0)
注意:这个答案和OP's answer是互补的。首先阅读OP's answer。
我今天在此问题上花费了4个小时。这是TensorFlow丑陋的地方之一(这就是为什么如果要进行图形操作,应该使用PyTorch的原因。)
这里的关键点是:tf.Variable
不是图形元素(有关here的更多内容),而是围绕3个操作的包装:Assign
op,Read
op和VariableV2
op(本质上是ref tensor
(更多信息here)。因此,您需要在TensorFlow Framework中显式调用它。
如果我们仔细查看graph_editor
的代码,尤其是transform module,我们会发现它仅在tf.Graph
上运行,而未涉及TensorFlow框架中的任何内容。因此,graph_editor.copy
(及类似方法)根本不触摸tf.Variable
对象。它只会复制tf.Variable
的构建基块的张量和运算。
好的,那我们怎么解决这个问题呢?
假设您具有以下变量:
var = tf.get_trainable_variables()[0]
print(var.to_proto())
# variable_name: "dense_1/kernel:0"
# initializer_name: "dense_1/kernel/Assign"
# snapshot_name: "dense_1/kernel/read:0"
# initial_value_name: "dense_1/random_uniform:0"
# trainable: true
您知道在graph_editor.copy(...)
之后,您的dense_1
名称范围现在为dense_1b
。然后,您需要使用info.transformed(...)
获取相应的操作数和张量,然后执行以下操作:
from tensorflow.core.framework import variable_pb2
var_def = variable_pb2.VariableDef()
var_def.variable_name = 'dense_1b/kernel:0'
var_def.initializer_name = "dense_1b/kernel/Assign"
var_def.snapshot_name = "dense_1b/kernel/read:0"
var_def.initial_value_name = "dense_1/random_uniform:0"
var_def.trainable = True
现在,我要强调tf.Variable
documentation的以下部分:
variable_def
:...重新创建变量对象及其内容,引用图中必须已存在的变量节点。该图未更改。
因此,tf.Variable
构造函数使我们可以在现有图形元素之上创建变量包装器。这正是我们所需要的:
cloned_var = tf.Variable(variable_def=var_def)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, cloned_var)
已解决!
我尽可能简单,明确地回答这个问题,以展示tf.Variables
的基本原理。现在,您可以轻松实现更通用的代码,以自动创建新变量。
PS:我讨厌TensorFlow!