如何在TensorFlow中取消引用_ref张量类型?

时间:2016-02-03 10:52:14

标签: python reference tensorflow dereference

如何将参考张量类型转换为值张量类型?

我发现的唯一方法是向张量添加零。有没有方便的方法?

assign以下是参考类型的张量。如何摆脱_ref

import tensorflow as tf

counter = tf.Variable(0, name="counter")

zero = tf.constant(0)
one = tf.constant(1)

new_counter = tf.add(counter, one)
assign = tf.assign(counter, new_counter) # dtype=int32_ref
result = tf.add(assign, zero) # dtype=int32
result2 = tf.convert_to_tensor(assign) # dtype=int32_ref
# result3 = assign.value() # has no attribute value

1 个答案:

答案 0 :(得分:4)

一般情况下,您应该能够在tf.foo_ref - 类型张量预期的任何地方使用tf.foo - 张量张量。 TensorFlow操作将隐式取消引用它们的输入参数(除非明确期望引用张量,例如在tf.assign()中)。

取消引用张量的最简单方法是使用tf.identity(),如下所示:

counter = tf.Variable(0)
assert counter.dtype == tf.int32_ref

counter_val = tf.identity(counter)
assert counter_val.dtype == tf.int32

请注意,这会回答您的问题,但可能会出现令人惊讶的语义,因为tf.identity() 不会复制基础缓冲区。因此,上例中的countercounter_val共享相同的缓冲区,counter的修改将反映在counter_val中:

counter = tf.Variable(0)
counter_val = tf.identity(counter)  # Take alias before the `assign_add` happens.
counter_update = counter.assign_add(1)

with tf.control_dependencies([counter_update]):
  # Force a copy after the `assign_add` happens.
  result = counter_val + 0

sess = tf.Session()
sess.run(tf.initialize_all_variables())

print sess.run(result)  # ==> 1  (result has effect of `assign_add`)