TensorFlow中图形集合的目的是什么?

时间:2015-12-12 02:26:40

标签: python tensorflow

API讨论Graph Collections,其中code判断是通用密钥/数据存储。这些藏品的目的是什么?

2 个答案:

答案 0 :(得分:16)

请记住,Tensorflow是一个用于指定然后执行计算数据流图的系统。图集合用作跟踪构造的图形以及它们必须如何执行的一部分。例如,当您创建某些类型的操作时,例如tf.train.batch_join,添加操作的代码也会向QUEUE_RUNNERS图表集合添加一些队列运行程序。稍后,当您拨打start_queue_runners()时,默认情况下,它会查看QUEUE_RUNNERS集合,以了解要启动哪些参赛者。

答案 1 :(得分:4)

我认为到目前为止至少有两个好处:

  1. 当您在多个GPU或计算机上分发程序时,可以方便地从同一集合中的不同设备收集损失。使用tf.add_n添加它们以累积损失。
  2. 以我自己的方式更新一组特定的变量,如权重和偏差。
  3. 例如:

    import tensorflow as tf    
    w = tf.Variable([1,2,3], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)    
    w2 = tf.Variable([11,22,32], collections=[tf.GraphKeys.WEIGHTS], dtype=tf.float32)
    weight_init_op = tf.variables_initializer(tf.get_collection_ref(tf.GraphKeys.WEIGHTS))
    sess = tf.InteractiveSession()
    weight_init_op.run()
    for vari in tf.get_collection_ref(tf.GraphKeys.WEIGHTS): 
        tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, vari.assign(0.2 * vari))
    weight_update_ops = tf.get_collection_ref(tf.GraphKeys.UPDATE_OPS)
    for op in weight_update_ops:
        print(op.eval())
    

    输出:

    [0.2 0.4 0.6]
    [2.2 4.4 6.4]