是否有一种替代方法可以在张量上应用tf.case,tf.cond等?

时间:2017-01-17 02:50:53

标签: tensorflow

我试图使用tf.case在张量中使用索引值来指向不同的网络结构部分,获得不同的损失,然后将它们总结为最终的训练损失。举一个简单的例子,我判断列表中的值并输出不同的值。例如 [0,1,2,3] - > [0,7,10,13]在哪里 案例0: 输出0 情况1: 输出7 案例2: 输出10 案例3: 输出13。 但是,tf.cond,tf.case似乎只用于标量。如何实现目标?

2 个答案:

答案 0 :(得分:2)

我所知道的唯一一个对向量的每个元素分别评估条件的操作是tf.where。你会离开x=None, y=None

t_orig = tf.constant([0, 1, 2, 3, 1])
t_filt = tf.where(tf.equal(t_orig, 1))
with tf.Session() as sess:
    print sess.run(t_filt)

输出:

[[1]
 [4]]

然而,这只评估单一条件的真实性。如果您想要评估多个条件的真实性,对于向量的每个元素,我认为您必须使用tf.map_fntf.case结合使用。 AFAIK,tf.case是评估给定值上许多条件真实性的唯一操作:

t_orig = tf.constant([0, 1, 2, 3])
t_new = tf.map_fn(
        lambda x: tf.case(
            pred_fn_pairs=[
                (tf.equal(x, 0), lambda: tf.constant(0)),
                (tf.equal(x, 1), lambda: tf.constant(7)),
                (tf.equal(x, 2), lambda: tf.constant(10)),
                (tf.equal(x, 3), lambda: tf.constant(13))],
            default=lambda: tf.constant(-1)),
        t_orig)
with tf.Session() as sess:
    print sess.run(t_new)

输出:

[ 0  7 10 13]

答案 1 :(得分:0)

尝试

import tensorflow as tf
value = [0, 1, 2, 3]
ones = tf.ones_like(value)
out = tf.where(tf.equal(value, 0), ones * 0,
               tf.where(tf.equal(value, 1), ones * 7,
                        tf.where(tf.equal(value, 2), ones * 10,
                                 tf.where(tf.equal(value, 3), ones * 13, ones * -1
                                          )
                                 )
                        )
               )

with tf.Session() as sess:
    print(sess.run(out)) # [ 0  7 10 13]