我有两个向量:时间和事件。如果一个事件为1,则应将同一索引处的时间分配给func_for_event1
。否则,它会转到func_for_event0
。
import tensorflow as tf
def func_for_event1(t):
return t + 1
def func_for_event0(t):
return t - 1
time = tf.placeholder(tf.float32, shape=[None]) # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None]) # [0, 1, 1, 0, 1]
# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.
我应该如何在Tensorflow中实现这个逻辑?说tf.cond
或tf.where
?
答案 0 :(得分:1)
这正是tf.where()
的用途。此代码(已测试):
import tensorflow as tf
import numpy as np
def func_for_event1(t):
return t + 1
def func_for_event0(t):
return t - 1
time = tf.placeholder(tf.float32, shape=[None]) # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None]) # [0, 1, 1, 0, 1]
result = tf.where( tf.equal( 1, event ), func_for_event1( time ), func_for_event0( time ) )
# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.
with tf.Session() as sess:
res = sess.run( result, feed_dict = {
time : np.array( [3.2, 4.2, 1.0, 1.05, 1.8] ),
event : np.array( [0, 1, 1, 0, 1] )
} )
print ( res )
输出:
[2.2 5.2 2. 0.04999995 2.8]
根据需要。