关于T.gt()的错误

时间:2015-08-01 03:39:34

标签: python if-statement theano

为了获得T.gt()的用法,我写了一个玩具代码。

def f(data):
    # return T.gt(data, 0)
    if T.gt(data, 0):
        print "1"
        return -data
    else:
        print "2"
        return data


a = T.scalar()
t = f(a)

print t.eval({a:-4})

我希望当a = -4时返回值为-4,当a = 4时返回-4,但它总是满足条件并运行return -data。 我不知道为什么。你能救我吗?

1 个答案:

答案 0 :(得分:2)

T.gt是一个符号函数;它不会返回一个布尔值,而是返回一个表示符号表达式的对象,该符号表达式在以后编译和执行时将计算为布尔值。

因此,在Python中,T.gt(...)将始终评估为True,因为结果始终为非None

如果您想在Theano中使用条件表达式,则需要使用符号条件运算。有两个:T.switchtheano.ifelse.ifelse。区别在于T.switch是一个逐元素的操作,接受一个张量条件,而ifelse需要一个标量条件。

您的示例还有另一个问题。即使代码很好,它也总会返回负值。实质上你的例子说,如果输入是肯定的,则返回其负数,否则返回输入原样(已经是负数)。我还建议在theano.function函数上使用eval

您的示例可能会被更改以说明ifelse的工作方式:

import theano
import theano.ifelse
import theano.tensor as T


def symbolic_f(x):
    return theano.ifelse.ifelse(T.gt(x, 0), -x - 1, x + 1)


def main():
    x = T.scalar()
    f = theano.function(inputs=[x], outputs=symbolic_f(x))

    print f(-4)
    print f(4)


main()