Tensorflow,tf.case()vs tf.cond(),逻辑运算

时间:2017-12-06 16:08:59

标签: python tensorflow

我有2个张量X和Y.

X = tf.constant([[-1,-2,-3],[4,5,6]], dtype=tf.float32)
Y = tf.constant([[-2,2,2],[-2,2,2]], dtype=tf.float32)

我想要的操作是:

if((X> 0和Y> 0)或(X <0和Y <0)):X + Y 否则:X-Y

我在tf.case上跟踪了示例:

示例2: 伪代码:

if (x < y && x > z) raise OpError("Only one predicate may evaluate true");
if (x < y) return 17;
else if (x > z) return 23;
else return -1;

表达式:

def f1(): return tf.constant(17)
def f2(): return tf.constant(23)
def f3(): return tf.constant(-1)
r = case({tf.less(x, y): f1, tf.greater(x, z): f2},
         default=f3, exclusive=True)

tf.case()无法工作的原因是条件必须是标量张量,因此它不支持2D张量。

Each pair contains a boolean scalar tensor and a python callable 

尝试使用tf.cond()并失败:

import tensorflow as tf

a = tf.constant([[-1,-2,-3],[4,5,6]], dtype=tf.float32)
b = tf.constant([[-2,2,2],[-2,2,2]], dtype=tf.float32)

zero = tf.fill(tf.shape(a), 0.0)

def f1(): return tf.add(a,b)
def f2(): return zero

c = tf.cond(tf.less(a,b), f1, f2)

with tf.Session() as session:
    r_c = session.run([c])
    print (r_c)

这是错误报告:

    Shape must be rank 0 but is rank 2 for 'cond/Switch' 
(op: 'Switch') with input shapes: [2,3], [2,3].

0 个答案:

没有答案