在Tensorflow中将变量增加为副本

时间:2017-01-31 10:44:50

标签: tensorflow

我目前有以下代码,我想用它来提供"流"递增整数。

import tensorflow as tf
...
record_count = tf.user_ops.my_custom_op(...) # something I write in C++ of Python
...
my_variable = tf.Variable(0, dtype=dtypes.int64)
my_var_incremented = my_variable.assign_add(math_ops.to_int64(record_count))
queued_increment = tf.train.input.batch((my_variable,), 1)

但问题是queued_increment只是对my_variable的引用,当我只想在递增后将my_variable的副本排入队列。

这是解决这个问题的正确方法,还是我错过了什么?

1 个答案:

答案 0 :(得分:1)

当与其他有状态构造(例如队列)交互时,当前的TensorFlow变量具有不幸的语义。问题源于“引用类型”(请注意my_variable.dtypetf.int64_ref,这意味着它是可变张量引用),大多数操作 - 包括队列 - 隐式“ dereference“通过创建一个”可变“缓冲区别名的”常量“张量。我们正在修复TensorFlow的变量内存模型中的这个错误,但变化不在公共API中。

与此同时,您最好的选择是在将变量插入队列时强制复制。这个最简单的解决方案依赖于未记录的行为,但tf.QueueBase.enqueue_many()将始终将其值复制到队列中,即使您将单个元素排入队列也是如此。通过tf.train.batch()使用时,您只需要重新整形变量(例如使用tf.expand_dims())并传递enqueue_many=True。例如:

my_variable = tf.Variable(0, dtype=dtypes.int64)
# ...
queued_increment = tf.train.batch((tf.expand_dims(my_variable, 1),), 1,
                                  enqueue_many=True)