使用tf.contrib.graph_editor克隆网络

时间:2017-08-26 14:08:47

标签: tensorflow

我有一个这样的功能,可以构建一个网络。

def build_network(inputs):
  # Some arbitrary set of variables and ops here. For example...
  out = tf.contrib.layers.fully_connected(inputs, 123)
  (...)
  return out

然后我用它来建立这样的网络。

inputs = tf.placeholder(...)
outputs = build_network(inputs)

如果我想构建更多具有相同结构但独立变量的网络,我只需要在其他变量范围和其他输入下再次调用 build_network

我的问题是:如果 build_network 不再可用,但原始网络的输入和输出是什么,我怎么能这样做呢?换句话说:如何将输出中的整个子图一直克隆到输入到另一个变量范围内,并使用自己独立的变量集但结构相同?

我的理解是,一般来说tf.contrib.graph_editor和graph_editor.copy正是我做这些事情所需的工具。但是,我找不到任何好用的例子。有什么建议吗?

2 个答案:

答案 0 :(得分:2)

回应自己,我发现了一种复制子图的方法。

from tensorflow.contrib import graph_editor as ge

# From the example above.
inputs = [tf.placeholder(...), ...]
outputs = build_network(inputs)

sgv = ge.make_view(ge.get_within_boundary_ops(
    tf.get_default_graph(),
    [t.op for t in outputs],
    [t.op for t in inputs]))

# This could be any new inputs. In this example I build new identical placeholders.
new_inputs = {p: tf.placeholder(dtype=p.dtype, shape=p.shape) for p in inputs}
new_sgv, info = ge.copy_with_input_replacements(sgv, new_inputs, dst_scope='copy')

new_inputs = [info.transformed(t) for t in inputs]
new_outputs = [info.transformed(t) for t in outputs]

但是,现在我在尝试使用网络副本时遇到了一个新问题。副本中的新变量未初始化,尝试运行tf.global_variables_initializer()也无济于事。

原因是因为这些的tf.Variable从未构建过,所以它们不是GlobalKeys.GLOBAL_VARIABLES集合的一部分。我可以很容易地找到对应于这些变量的ops以及它们在原始和副本之间的映射,但是我无法从中构建一个tf.Variable。

我发现了一些hacky解决方法来进行初始化,但它只适用于集合中的变量。

init_ops = []
for v in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
  if v.op in sgv.ops:
    init_ops.append(info.transformed(v.initializer))

...

session.run([tf.global_variables_initializer()] + init_ops)

有没有更好的方法呢?理想情况下,允许为复制的vars创建tf.Variables以将其添加到全局变量集合。或者,如果不可能,至少可以获得初始化操作,而不必找到原始网络的tf.Variable对象。

答案 1 :(得分:0)

  

注意:这个答案和OP's answer是互补的。首先阅读OP's answer

我今天在此问题上花费了4个小时。这是TensorFlow丑陋的地方之一(这就是为什么如果要进行图形操作,应该使用PyTorch的原因。)


这里的关键点是tf.Variable不是图形元素(有关here的更多内容),而是围绕3个操作的包装:Assign op,Read op和VariableV2 op(本质上是ref tensor(更多信息here)。因此,您需要在TensorFlow Framework中显式调用它。

如果我们仔细查看graph_editor的代码,尤其是transform module,我们会发现它仅在tf.Graph上运行,而未涉及TensorFlow框架中的任何内容。因此,graph_editor.copy(及类似方法)根本不触摸tf.Variable对象。它只会复制tf.Variable的构建基块的张量和运算。

  

好的,那我们怎么解决这个问题呢?

假设您具有以下变量:

var = tf.get_trainable_variables()[0]
print(var.to_proto())
# variable_name: "dense_1/kernel:0"
# initializer_name: "dense_1/kernel/Assign"
# snapshot_name: "dense_1/kernel/read:0"
# initial_value_name: "dense_1/random_uniform:0"
# trainable: true

您知道在graph_editor.copy(...)之后,您的dense_1名称范围现在为dense_1b。然后,您需要使用info.transformed(...)获取相应的操作数和张量,然后执行以下操作:

from tensorflow.core.framework import variable_pb2

var_def = variable_pb2.VariableDef()
var_def.variable_name = 'dense_1b/kernel:0'
var_def.initializer_name = "dense_1b/kernel/Assign"
var_def.snapshot_name = "dense_1b/kernel/read:0"
var_def.initial_value_name = "dense_1/random_uniform:0"
var_def.trainable = True

现在,我要强调tf.Variable documentation的以下部分:

  

variable_def:...重新创建变量对象及其内容,引用图中必须已存在的变量节点。该图未更改。

因此,tf.Variable构造函数使我们可以在现有图形元素之上创建变量包装器。这正是我们所需要的:

cloned_var = tf.Variable(variable_def=var_def)
tf.add_to_collection(tf.GraphKeys.GLOBAL_VARIABLES, cloned_var)

已解决!


我尽可能简单,明确地回答这个问题,以展示tf.Variables的基本原理。现在,您可以轻松实现更通用的代码,以自动创建新变量。

PS:我讨厌TensorFlow!