有没有办法找到给定操作(通常是损失)所依赖的所有变量?
我想使用此功能,然后使用各种optimizer.minimize()
组合将此集合传递到tf.gradients()
或set().intersection()
。
到目前为止,我已找到op.op.inputs
并尝试了一个简单的BFS,但我永远不会碰到Variable
或tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
slim.get_variables()
个对象
相应的'Tensor.op._id and
Variables.op._id`字段之间似乎存在对应关系,但我不确定这是否应该依赖。
或者我可能不想在第一时间这样做? 我当然可以在构建我的图形时精心构建我不相交的变量集,但是如果我改变模型就很容易错过。
答案 0 :(得分:4)
documentation for tf.Variable.op
并不是特别清楚,但它确实引用the implementation of a tf.Variable
中使用的关键tf.Operation
:依赖于tf.Variable
的任何操作都将在路径上从那次行动。由于tf.Operation
对象是可清除的,因此您可以将其用作将dict
对象映射到相应tf.Operation
对象的tf.Variable
的键,然后像以前一样执行BFS :
op_to_var = {var.op: var for var in tf.trainable_variables()}
starting_op = ...
dependent_vars = []
queue = collections.deque()
queue.append(starting_op)
visited = set([starting_op])
while queue:
op = queue.popleft()
try:
dependent_vars.append(op_to_var[op])
except KeyError:
# `op` is not a variable, so search its inputs (if any).
for op_input in op.inputs:
if op_input.op not in visited:
queue.append(op_input.op)
visited.add(op_input.op)