Tensorflow中计算图中的节点数没有意义

时间:2017-05-19 15:46:32

标签: python tensorflow

在下面的代码中,我已经评论了每行添加到计算图中的节点数:

import tensorflow as tf

tf.reset_default_graph()

W = tf.Variable([.3], tf.float32)#4
b = tf.Variable([-.3], tf.float32)#4
x = tf.placeholder(tf.float32)#1
y = tf.placeholder(tf.float32)#1
linear_model = W * x + b#2
loss = tf.reduce_sum(tf.square(linear_model - y)) # 7
optimizer = tf.train.GradientDescentOptimizer(0.01) # 0
train = optimizer.minimize(loss)#59
init = tf.global_variables_initializer()

但这些数字对我来说没有意义。例如,为什么tf.Variable行,图表中添加了四个节点?为什么optimizer.minimize(loss)添加了59个节点?并且tf.train.GradientDescentOptimizer没有添加任何节点?

1 个答案:

答案 0 :(得分:2)

因为这些行是Python包装器,可以转换为更低级别的TensorFlow操作。

通常,来自gen_xxx_ops.py文件的任何操作都会直接转换为单个TensorFlow节点(即gen_math_ops.py),但tf.add中定义的内容类似xxx_ops.py math_ops.py做一些额外的Python内容,可以转换为多个节点。

例如,考虑tf.reduce_sum。在Jupyter中使用inspect模块或tf.reduce_sum??,您可以在tensorflow/python/ops/math_ops.py中看到它的定义,具有以下定义:

  return gen_math_ops._sum(
      input_tensor,
      _ReductionDims(input_tensor, axis, reduction_indices),
      keep_dims,
      name=name)

_sum调用会创建一个Sum节点,但ReductionDims还会创建TensorFlow节点,用于定义求和的起始和结束索引。

如果你看一下tf.train.GradientDescentOptimizer的定义,你会发现构造函数没有定义任何计算,它只保存lr参数。计算已添加到minimizeapply_gradients

您可以使用以下辅助函数来确切地查看添加了哪些节点

import tensorflow as tf
from pprint import pprint

tf.reset_default_graph()

with capture_ops() as ops:
    W = tf.Variable([.3], tf.float32)
pprint(ops)

其中capture_ops定义为

import contextlib
@contextlib.contextmanager
def capture_ops():
  """Captures any ops added to the tf Graph within this block."""
  from tensorflow.python.framework import ops
  old_create_op =  ops.Graph.create_op
  op_list = []
  def new_create_op(graph_object, op_type, inputs, dtypes, input_types=None, name=None, attrs=None, op_def=None, compute_shapes=True, compute_device=True):
    # todo: remove keyword args
    op = old_create_op(graph_object, op_type=op_type, inputs=inputs, dtypes=dtypes, input_types=input_types, name=name, attrs=attrs, op_def=op_def, compute_shapes=compute_shapes, compute_device=compute_device)
    op_list.append(op)
    return op
  ops.Graph.create_op = new_create_op
  yield op_list
  ops.Graph.create_op = old_create_op