tf.while_loop解决高维tf.where问题?

时间:2019-03-27 20:56:40

标签: python tensorflow

我正在尝试训练一个神经网络,该神经网络在不同的预定区域上输出不同的概率。我在编写损失函数-对数似然函数时遇到麻烦。

更具体地说,我拥有的数据的格式为(X,t),其中X是输入要素,t是响应变量-正标量。我想估计给定X的t的概率密度。

我使用的方法如下:我根据矢量Tau将正实线划分为间隔。例如,如果Tau = [0,5,10],那么我将正实线划分为[0,5),[5,10)和[10,\ infty)。现在我想训练一个神经网络来输出p_1,p_2,p_3,这表明t处于这三个间隔之一中的概率。

数据点(X,t = 2)的似然函数将简单地为p_1,数据点(X,t = 15)的似然函数为p_2。这可以通过对单个数据点使用tf.where来完成。但我不知道如何批量进行此操作。

现在,当使用占位符进行批次梯度下降时,我从神经网络输出的张量P将具有(?,3)的形状。 “?”是未指定的批量大小。并且该批次T中的数据将具有形状(?,1)。现在,对于批处理中的每个数据点,我必须找出该数据点属于哪个区域,并相应地确定该数据点的似然函数。我尝试过tf.cond或tf.case,但我知道它不支持广播。因此,我的最后希望是tf.where。在我看来,这是在tf.where中加上“?”时间,但我不知道具体如何。任何帮助将非常感激。

我尝试了以下使用tf.while_loop和tf.where的代码,但是我已经遇到了ValueError。在示例代码中,Tau定义了时间间隔,而T是响应变量的示例批次。还没有神经网络。我只是想对这个玩具问题进行编码,以查看它是否有效。

graph = tf.Graph()
with graph.as_default():
    Tau = [0, 10, 20, 30,40]
    i = tf.Variable(0)
    T = tf.convert_to_tensor(  np.array((12,24) ) )
    Tau_tf = tf.convert_to_tensor(Tau)

    def cond(i, T, Tau_tf):
        return tf.less(i, Tau_tf.shape[0])
    def body(i, T, Tau_tf):
        i_fill = tf.fill((2,), i+1)
        index = tf.where( (T > Tau_tf[i]) & (T < Tau_tf[i+1]), i_fill, T )
        tf.add(i,1)
        return index
    index = tf.while_loop(cond, body, [i, T, Tau_tf])

with tf.Session() as sess:
    tf.global_variables_initializer().run()
    index = sess.run(index)
    print(index)

在这段代码中,我想输出index = [2,3],因为T = [12,24],而12落入第二个间隔,而24落入第三个间隔。

0 个答案:

没有答案