theano和lambda函数

时间:2016-07-22 07:01:10

标签: python lambda theano

当我有一个评估theano表达式的lambda函数列表时,我有一些奇怪的行为。代码如下:

# Equivalent functions (or at least I assume so)
def tilted_loss(q,y,f):
    e = (y-f)
    return (q*tt.sum(e)-tt.sum(e[(e<0).nonzero()]))/e.shape[0]

def tilted_loss2(y,f):
    q = 0.05
    e = (y-f)
    return (q*tt.sum(e)-tt.sum(e[(e<0).nonzero()]))/e.shape[0]

def tilted_loss_np(q,y,f):
    e = (y-f)
    return (q*sum(e)-sum(e[e<0]))/e.shape[0]

# lambda functions which uses above functions
qs = np.arange(0.05,1,0.05)
q_loss_f = [lambda y,f: tilted_loss(q,y,f) for q in qs]
q_loss_f2 = lambda y,f:tilted_loss(0.05,y,f)
q_loss_f3 = lambda y,f:tilted_loss(qs[0],y,f)

# Test the functions
np.random.seed(1)
a = np.random.randn(1000,1)
b = np.random.randn(1000,1)
print(q_loss_f[0](a,b).eval())
print(q_loss_f2(a,b).eval())
print(q_loss_f3(a,b).eval())
print(tilted_loss2(a,b).eval())
print(tilted_loss_np(qs[0],a,b)[0])

这给出了输出:

0.571973847658054
0.5616355181780912
0.5616355181695327
0.5616355181780912
0.56163551817
  1. 我必须对定义函数列表q_loss_f的方式做错。
  2. q定义的方式好吗?即我发送的一个numpy变量,但q_loss_f3似乎没问题。
  3. 有什么想法吗?

1 个答案:

答案 0 :(得分:1)

是常见错误,lambda expresion中的q值只取理解循环中的最后一个值,最好使用partial

q_loss_f = [partial(tilted_loss, q=q) for q in qs]