我试图使用tf.case在张量中使用索引值来指向不同的网络结构部分,获得不同的损失,然后将它们总结为最终的训练损失。举一个简单的例子,我判断列表中的值并输出不同的值。例如 [0,1,2,3] - > [0,7,10,13]在哪里 案例0: 输出0 情况1: 输出7 案例2: 输出10 案例3: 输出13。 但是,tf.cond,tf.case似乎只用于标量。如何实现目标?
答案 0 :(得分:2)
我所知道的唯一一个对向量的每个元素分别评估条件的操作是tf.where。你会离开x=None, y=None
:
t_orig = tf.constant([0, 1, 2, 3, 1])
t_filt = tf.where(tf.equal(t_orig, 1))
with tf.Session() as sess:
print sess.run(t_filt)
输出:
[[1]
[4]]
然而,这只评估单一条件的真实性。如果您想要评估多个条件的真实性,对于向量的每个元素,我认为您必须使用tf.map_fn
与tf.case
结合使用。 AFAIK,tf.case
是评估给定值上许多条件真实性的唯一操作:
t_orig = tf.constant([0, 1, 2, 3])
t_new = tf.map_fn(
lambda x: tf.case(
pred_fn_pairs=[
(tf.equal(x, 0), lambda: tf.constant(0)),
(tf.equal(x, 1), lambda: tf.constant(7)),
(tf.equal(x, 2), lambda: tf.constant(10)),
(tf.equal(x, 3), lambda: tf.constant(13))],
default=lambda: tf.constant(-1)),
t_orig)
with tf.Session() as sess:
print sess.run(t_new)
输出:
[ 0 7 10 13]
答案 1 :(得分:0)
尝试
import tensorflow as tf
value = [0, 1, 2, 3]
ones = tf.ones_like(value)
out = tf.where(tf.equal(value, 0), ones * 0,
tf.where(tf.equal(value, 1), ones * 7,
tf.where(tf.equal(value, 2), ones * 10,
tf.where(tf.equal(value, 3), ones * 13, ones * -1
)
)
)
)
with tf.Session() as sess:
print(sess.run(out)) # [ 0 7 10 13]