解释张量流的FLOPs配置文件结果

时间:2018-07-11 06:17:04

标签: tensorflow profiler flops

我想分析一个非常简单的神经网络模型的FLOP,该模型用于对MNIST数据集进行分类,批处理大小为128。按照官方教程进行操作时,我得到了以下模型的结果,但是我无法理解输出的某些部分。

w1 = tf.Variable(tf.random_uniform([784, 15]), name='w1')
w2 = tf.Variable(tf.random_uniform([15, 10]), name='w2')
b1 = tf.Variable(tf.zeros([15, ]), name='b1')
b2 = tf.Variable(tf.zeros([10, ]), name='b2')

hidden_layer = tf.add(tf.matmul(images_iter, w1), b1)
logits = tf.add(tf.matmul(hidden_layer, w2), b2)

loss_op = tf.reduce_sum(\
    tf.nn.softmax_cross_entropy_with_logits(logits=logits, 
                                            labels=labels_iter))
opetimizer = tf.train.AdamOptimizer(learning_rate=0.01)
train_op = opetimizer.minimize(loss_op)

images_iterlabels_iter是tf.data的迭代器,类似于占位符。

tf.profiler.profile(
    tf.get_default_graph(),
    options=tf.profiler.ProfileOptionBuilder.float_operation())

我使用此代码(相当于tfprof注释行工具中的scope -min_float_ops 1 -select float_ops -account_displayed_op_only)来分析FLOP,并得到以下结果。

Profile:
node name | # float_ops
_TFProfRoot (--/23.83k flops)
  random_uniform (11.76k/23.52k flops)
    random_uniform/mul (11.76k/11.76k flops)
    random_uniform/sub (1/1 flops)
  random_uniform_1 (150/301 flops)
    random_uniform_1/mul (150/150 flops)
    random_uniform_1/sub (1/1 flops)
  Adam/mul (1/1 flops)
  Adam/mul_1 (1/1 flops)
  softmax_cross_entropy_with_logits_sg/Sub (1/1 flops)
  softmax_cross_entropy_with_logits_sg/Sub_1 (1/1 flops)
  softmax_cross_entropy_with_logits_sg/Sub_2 (1/1 flops)

我的问题是

  1. 括号中的数字是什么意思?例如random_uniform_1 (150/301 flops),150和301是什么?
  2. 为什么_TFProfRoot括号“-”中的第一个数字是?
  3. 为什么Adam / mul和softmax_cross_entropy_with_logits_sg / Sub 1的翻牌圈?

我知道这么长时间阅读一个问题会令人沮丧,但是一个绝望的男孩无法从官方文档中找到相关信息,需要你们的帮助。

1 个答案:

答案 0 :(得分:1)

我会尝试的:

(1)在此示例中,看起来第一个数字是“自我”翻牌,第二个数字表示命名范围内的“总数”翻牌。例如:对于分别命名为random_uniform(如果有这样的节点)的3个节点,random_uniform / mul,random_uniform / sub,它们分别耗费11.76k,11.76k和1 flop,总共耗费23.52k flop。

另一个例子:23.83k = 23.52k + 300。

这有意义吗?

(2)根节点是探查器添加的“虚拟”顶级节点,它没有“自我”触发器,换句话说,它具有零自我触发器。

(3)不确定为什么是1。如果可以使用print(sess.graph_def)打印GraphDef并找出该节点的真正含义,这将有所帮助

希望这会有所帮助。