使用自定义激活功能时出现分段错误

时间:2019-10-08 12:22:42

标签: pytorch

我正在尝试实现自定义激活功能(下面随附的代码)。在使用自定义激活功能之前,一切正常。但是,只要使用它,服务器就会抛出错误:

分段错误

错误总是出现在第一个时期。

我正在使用

Pytorch 1.1.0 Cuda编译工具,版本9.2,V9.2.148 代码

def mg(x):

    c = 1.33
    b = 0.4
    p = 6.88
    input_size = x.shape
    num = torch.numel(x) # the element number of the input tensor
    x = x.view(num)

    out = torch.zeros(len(x))

    for i in range(len(x)):
    if x[i] < 0:
            out[i] = 0
        else:
            out[i] = (c * x[i]) / (1 + torch.mul(b * p, torch.pow(x[i], p)))

    out = out.view(input_size[0], input_size[1], input_size[2], input_size[3])
    return out

1 个答案:

答案 0 :(得分:3)

您正在使用新创建的out打破渐变。

您应该修改代码以对x输入进行操作。另外,您不应该使用任何循环(几乎总是有一种方法可以不使用它们)。鉴于此,此功能应与您的功能相同,但可以起作用:

def mg(x, c=1.33, b=0.4, p=6.88):
    input_size = x.shape
    x = x.flatten()

    x[x < 0] = 0
    x[x != 0] *= c
    x[x != 0] /= 1 + b * p * x[x != 0] ** p

    return x.reshape(*input_size)

如果仍然出现错误,则可能与程序的其他部分有关。