如何获得对ReadVariableOp正在读取的Variable对象的引用?

时间:2019-01-24 08:22:17

标签: python tensorflow

我具有以下形式的代码:

project_ops = []
for op in tf.get_default_graph().get_operations():
  if op.type == 'Conv2D':
    activations, kernel = op.inputs
    batch_size, height, width, num_channels = activations.shape.as_list()
    kernel_size_height, kernel_size_width, input_channels, output_channels = kernel.shape.as_list()
    print(activations.shape.as_list(), kernel.shape.as_list())
    project_ops.append(tf.assign(kernel, Orthoganalize(kernel, [height, width])))

这不起作用,因为kernel不是变量,而是ReadVariableOp。我希望从中获取变量,但是它似乎没有在Python中可访问的变量的引用?

1 个答案:

答案 0 :(得分:0)

这可能不是一种好方法,但是您可以使用以下功能:

def get_var_from(readOp): # make sure to pass a tf.Operation with
                          # readOp.type == 'ReadVariableOp'
    handleOp = readOp.inputs[0].op
    for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
        if var.op is handleOp:
            return var

这将使用一个tf.Variable(*)唯一的VarHandleOp实例,并读取所有变量以找到一个使用Handle的变量。 (请注意,ReadVariableOp不是变量唯一的。可用于重新读取变量。)

我看不到从句柄到tf.Variable -instance的链接,也许Python API中没有任何链接,也许它是单向链接。

(*)据我了解,您使用ResourceVariables,因此此描述可能仅引用ResourceVariables。