在Tensorflow图中交换两个变量的值的最有效方法是什么?

时间:2018-09-21 11:45:10

标签: tensorflow

给出一个带有两个变量var1和var2的Tensorflow图,我想将var1的值赋给var2,反之亦然。一种简单的方法是(发布结束后的MWE)

var_tmp = var1.eval(session=sess)
sess.run([tf.assign(var1, var2])
sess.run([tf.assign(var2, var_tmp)])

但是,如果有几对这样的变量(例如,模型参数和关联的指数加权移动平均值),则此方法很快就会变得相当慢,因为run每次被调用三次,实际上会造成内存泄漏(https://github.com/tensorflow/tensorflow/issues/4151#issuecomment-244089247)。此外,由于将var_tmp放在CPU上,因此如果var1和var2在GPU上,由于数据传输,执行速度甚至会变慢。

对于几对变量,我想避免创建图的完整副本来保存临时变量。

是否可以定义单个操作来为一对变量执行此操作?甚至更好,几对?

MWE:

import tensorflow as tf

var1 = tf.Variable(1)  # 'Variable:0' 
var2 = tf.Variable(2)  # 'Variable_1:0'
sess = tf.Session()
sess.run(tf.global_variables_initializer())

var_tmp = var1.eval(session=sess)
sess.run([tf.assign(var1, var2)])
sess.run([tf.assign(var2, var_tmp)])

print(var1.eval(session=sess))
print(var2.eval(session=sess))

2 个答案:

答案 0 :(得分:1)

最好的方法是使用资源变量(1.11之后是tf.enable_resource_variables(),之后是tf.get_variable_scope().set_use_resource(True))和类似这样的图

 a_value = a.read_value()
 b_value = b.read_value()
 with tf.control_dependencies([a_value, b_value]):
   ops = a.assign(b_value), b.assign(a_value)

 sess.run(ops)

答案 1 :(得分:0)

在不深入研究TensorFlow的细节的情况下,我建议您使用此-

var1, var2 = var2, var1

如果您想交换多对,请尝试-

for var1, var2 in pairs:
    var1, var2 = var2, var1

希望这会有所帮助。