Tensorflow,tf.where的参数

时间:2018-07-01 11:59:42

标签: python tensorflow

我正在使用张量流,对tf.where函数感到困惑。

我们知道,该函数包含三个参数:cond,x和y。如果x和y都不为None,它将返回一个包含张量的张量,当cond为true时返回x的元素,而cond为false则返回y的元素。

但是,我发现代码的结果很混乱。代码如下:

import tensorflow as tf
import numpy as np
sess = tf.Session()

variable_bool1 = tf.convert_to_tensor(np.array([True, False]), dtype=tf.bool)
variable_bool2 = tf.convert_to_tensor(tf.Variable([True, False]), dtype=tf.bool)
variable_bool3 = tf.convert_to_tensor(tf.constant([True, False]), dtype=tf.bool)
variable_number1 = tf.convert_to_tensor(np.array([1.0, 2.0]), dtype=tf.float32)
variable_number2 = tf.convert_to_tensor(np.array([3.0, 4.0]), dtype=tf.float32)
where_result1 = tf.where(variable_bool1, variable_number1, variable_number2)
where_result2 = tf.where(variable_bool2, variable_number1, variable_number2)
where_result3 = tf.where(variable_bool3, variable_number1, variable_number2)
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

print variable_bool1, sess.run(variable_bool1), sess.run([where_result1])
print variable_bool2, sess.run(variable_bool2), sess.run([where_result2])
print variable_bool3, sess.run(variable_bool3), sess.run([where_result3])

结果如下:

Tensor("Const:0", shape=(2,), dtype=bool) [ True False] [array([1., 4.], dtype=float32)]
Tensor("Variable/read:0", shape=(2,), dtype=bool) [ True False] [array([0., 0.], dtype=float32)]
Tensor("Const_1:0", shape=(2,), dtype=bool) [ True False] [array([1., 4.], dtype=float32)]

似乎variable_bool2无法返回预期结果。

是变量和常量之间的区别吗?

为什么variable_bool2总是返回零?

您能解释一下吗?非常感谢你!

0 个答案:

没有答案