查找tensorflow操作依赖的所有变量

时间:2017-03-17 10:50:51

标签: tensorflow

有没有办法找到给定操作(通常是损失)所依赖的所有变量? 我想使用此功能,然后使用各种optimizer.minimize()组合将此集合传递到tf.gradients()set().intersection()

到目前为止,我已找到op.op.inputs并尝试了一个简单的BFS,但我永远不会碰到Variabletf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

返回的slim.get_variables()个对象

相应的'Tensor.op._id and Variables.op._id`字段之间似乎存在对应关系,但我不确定这是否应该依赖。

或者我可能不想在第一时间这样做? 我当然可以在构建我的图形时精心构建我不相交的变量集,但是如果我改变模型就很容易错过。

1 个答案:

答案 0 :(得分:4)

documentation for tf.Variable.op并不是特别清楚,但它确实引用the implementation of a tf.Variable中使用的关键tf.Operation:依赖于tf.Variable的任何操作都将在路径上从那次行动。由于tf.Operation对象是可清除的,因此您可以将其用作将dict对象映射到相应tf.Operation对象的tf.Variable的键,然后像以前一样执行BFS :

op_to_var = {var.op: var for var in tf.trainable_variables()}

starting_op = ...
dependent_vars = []

queue = collections.deque()
queue.append(starting_op)

visited = set([starting_op])

while queue:
  op = queue.popleft()
  try:
    dependent_vars.append(op_to_var[op])
  except KeyError:
    # `op` is not a variable, so search its inputs (if any). 
    for op_input in op.inputs:
      if op_input.op not in visited:
        queue.append(op_input.op)
        visited.add(op_input.op)