根据Tensorflow

时间:2018-05-14 23:06:28

标签: python tensorflow

我有两个向量:时间和事件。如果一个事件为1,则应将同一索引处的时间分配给func_for_event1。否则,它会转到func_for_event0

import tensorflow as tf

def func_for_event1(t):
    return t + 1

def func_for_event0(t):
    return t - 1

time = tf.placeholder(tf.float32, shape=[None])  # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None])  # [0, 1, 1, 0, 1]

# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.

我应该如何在Tensorflow中实现这个逻辑?说tf.condtf.where

1 个答案:

答案 0 :(得分:1)

这正是tf.where()的用途。此代码(已测试):

import tensorflow as tf
import numpy as np

def func_for_event1(t):
    return t + 1

def func_for_event0(t):
    return t - 1

time = tf.placeholder(tf.float32, shape=[None])  # [3.2, 4.2, 1.0, 1.05, 1.8]
event = tf.placeholder(tf.int32, shape=[None])  # [0, 1, 1, 0, 1]

result = tf.where( tf.equal( 1, event ), func_for_event1( time ), func_for_event0( time ) )
# result: [2.2, 5.2, 2.0, 0.05, 2.8]
# For example, 3.2 should be sent to func_for_event0 because the first element in event is 0.

with tf.Session() as sess:
    res = sess.run( result, feed_dict = {
        time : np.array( [3.2, 4.2, 1.0, 1.05, 1.8] ),
        event : np.array( [0, 1, 1, 0, 1] )
    } )
    print ( res )

输出:

  

[2.2 5.2 2. 0.04999995 2.8]

根据需要。