我有2个张量X和Y.
X = tf.constant([[-1,-2,-3],[4,5,6]], dtype=tf.float32)
Y = tf.constant([[-2,2,2],[-2,2,2]], dtype=tf.float32)
我想要的操作是:
if((X> 0和Y> 0)或(X <0和Y <0)):X + Y 否则:X-Y
我在tf.case上跟踪了示例:
示例2: 伪代码:
if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
if (x < y) return 17;
else if (x > z) return 23;
else return -1;
表达式:
def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = case({tf.less(x, y): f1, tf.greater(x, z): f2},
default=f3, exclusive=True)
tf.case()无法工作的原因是条件必须是标量张量,因此它不支持2D张量。
Each pair contains a boolean scalar tensor and a python callable
尝试使用tf.cond()并失败:
import tensorflow as tf
a = tf.constant([[-1,-2,-3],[4,5,6]], dtype=tf.float32)
b = tf.constant([[-2,2,2],[-2,2,2]], dtype=tf.float32)
zero = tf.fill(tf.shape(a), 0.0)
def f1(): return tf.add(a,b)
def f2(): return zero
c = tf.cond(tf.less(a,b), f1, f2)
with tf.Session() as session:
r_c = session.run([c])
print (r_c)
这是错误报告:
Shape must be rank 0 but is rank 2 for 'cond/Switch'
(op: 'Switch') with input shapes: [2,3], [2,3].