pytorch中无效的参数组合

时间:2018-03-16 13:45:59

标签: python pytorch

我在识别错误​​方面遇到了一些麻烦。 这是我收到的错误消息:

  

返回torch.exp( - ((x-2)** 2))+ 0.8 * torch.exp( - (x + 2)** 2)

     

TypeError:torch.exp接收到无效的参数组合 - got(!float!),但是期望(torch.FloatTensor source)

import torch
import time
dtype = torch.FloatTensor

def functionFit(x):
    return torch.exp(-((x-2)**2)) + 0.8*torch.exp(-(x+2)**2)

def createSample(mu,sig,N):
    return  mu + sig*torch.randn(N,1).type(dtype)

def updateMU(alpha,rho,x,N,I,mu):
    return alpha*torch.mean(x[I[int((1-rho)*N):N,0]]) + (1-alpha)*mu

def updateSIG(alpha,rho,x,N,I,sig):
    return alpha*torch.std(x[I[int((1-rho)*N):N,0]])  + (1-alpha)*sig  

def CE(N,rho,alpha,epsilon,mu,sig): # initial std dev.
    start = time.time()
    k = 0
    while (sig > epsilon):
        x            = createSample(mu,sig,N)
        S            = functionFit(x)
        sorted_v , I = torch.sort(S,0)
        mu           = updateMU(alpha,rho,x,N,I,mu)
        sig          = updateSIG(alpha,rho,x,N,I,sig)
        k = k + 1

    end = time.time()
    xm  = torch.mean(x)
    ym  = functionFit(xm)
    print('x =',xm)
    print('y =',ym)
    print('time =',end - start,'s')
    print('iter =',k)


if __name__ == '__main__':
    N       = 50
    rho     = 0.5
    alpha   = 0.9
    epsilon = 0.001
    mu      = 20*torch.rand(1,1).type(dtype)-10 # init mu
    sig     = 5
    CE(N,rho,alpha,epsilon,mu,sig)

1 个答案:

答案 0 :(得分:1)

CE函数中,以下两行导致错误。

xm  = torch.mean(x)
ym  = functionFit(xm)

此处x50 x 1 FloatTensor,但是当您调用torch.mean()时,它会返回一个浮点值,当您调用functionFit(xm)时会导致错误。

顺便说一句,仅为了您的信息,torch.mean()返回一个浮点值,torch.exp()期望一个张量作为输入。您可以在functionFit()中简单地检查参数的类型,如果参数是浮点值而不是张量,则使用numpy计算指数。