在pytorch中有theano.tensor.switch的模拟吗?

时间:2018-04-19 23:35:36

标签: deep-learning theano pytorch

我想强制将向量中所有低于某个阈值的元素归零。而且我想这样做,这样我仍然可以通过非零值传播渐变。

例如,在theano我可以写:

B = theano.tensor.switch(A < .1, 0, A)

在pytorch中是否有解决方案?

2 个答案:

答案 0 :(得分:1)

我不认为PyTorch中默认实现了switch。但是,您可以通过extending the torch.autograd.Function

在PyTorch中定义自己的函数

因此,切换功能看起来像

class switchFunction(Function):
    @staticmethod
    def forward(ctx, flag, value, tensor):
        ctx.save_for_backward(flag)
        tensor[flag] = value
        return tensor

    @staticmethod
    def backward(ctx, grad_output):
        flag, = ctx.saved_variables
        grad_output[flag] = 0
        return grad_output
switch = switchFunction.apply

现在,您只需将switch称为switch(A < 0.1, 0, A)

即可

修改

实际上有一个功能可以做到这一点。它被称为Threshold。您可以像

一样使用它
import torch.nn as nn
m = nn.Threshold(0.1, 0)
B = m(A)

答案 1 :(得分:1)

从pytorch 0.4+开始,您可以使用torch.where轻松完成(请参阅docMerged PR

就像Theano一样容易。看看自己的例子:

import torch
from torch.autograd import Variable

x = Variable(torch.arange(0,4), requires_grad=True) # x     = [0 1 2 3]
zeros = Variable(torch.zeros(*x.shape))             # zeros = [0 0 0 0]

y = x**2                         # y = [0 1 4 9]
z = torch.where(y < 5, zeros, y) # z = [0 0 0 9]

# dz/dx = (dz/dy)(dy/dx) = (y < 5)(0) + (y ≥ 5)(2x) = 2x(x**2 ≥ 5) 
z.backward(torch.Tensor([1.0])) 
x.grad # (dz/dx) = [0 0 0 6]