我正在PyTorch中构建变体自动编码器(VAE),并且在编写设备不可知代码时遇到问题。自动编码器是nn.Module
的子级,也具有编码器和解码器网络。可以通过调用net.to(device)
将网络的所有权重从一台设备移动到另一台设备。
我遇到的问题是重新参数化技巧:
encoding = mu + noise * sigma
噪声是与mu
和sigma
相同大小的张量,并保存为自动编码器模块的成员变量。它在构造函数中初始化,并在每个训练步骤中就地重新采样。我这样做是为了避免在每个步骤中构造一个新的噪声张量并将其推入所需的设备。另外,我想修复评估中的噪音。这是代码:
class VariationalGenerator(nn.Module):
def __init__(self, input_nc, output_nc):
super(VariationalGenerator, self).__init__()
self.input_nc = input_nc
self.output_nc = output_nc
embedding_size = 128
self._train_noise = torch.randn(batch_size, embedding_size)
self._eval_noise = torch.randn(1, embedding_size)
self.noise = self._train_noise
# Create encoder
self.encoder = Encoder(input_nc, embedding_size)
# Create decoder
self.decoder = Decoder(output_nc, embedding_size)
def train(self, mode=True):
super(VariationalGenerator, self).train(mode)
self.noise = self._train_noise
def eval(self):
super(VariationalGenerator, self).eval()
self.noise = self._eval_noise
def forward(self, inputs):
# Calculate parameters of embedding space
mu, log_sigma = self.encoder.forward(inputs)
# Resample noise if training
if self.training:
self.noise.normal_()
# Reparametrize noise to embedding space
inputs = mu + self.noise * torch.exp(0.5 * log_sigma)
# Decode to image
inputs = self.decoder(inputs)
return inputs, mu, log_sigma
现在我用net.to('cuda:0')
将自动编码器移至GPU时,由于噪声张量未移动,因此转发时出错。
我不想将设备参数添加到构造函数中,因为那样以后仍然无法将其移动到另一个设备。我还尝试将噪声包装到nn.Parameter
中,以使其受到net.to()
的影响,但是由于噪声被标记为requires_grad=False
,所以这会带来优化器的错误。
任何人都有使用net.to()
移动所有模块的解决方案吗?
答案 0 :(得分:1)
tilman151's second approach的更好版本可能是覆盖_apply
,而不是to
。这样,net.cuda()
,net.float()
等也都可以工作,因为它们都调用_apply
而不是to
(在the source中可以看到,比您想象的要简单):
def _apply(self, fn):
super(VariationalGenerator, self)._apply(fn)
self._train_noise = fn(self._train_noise)
self._eval_noise = fn(self._eval_noise)
return self
答案 1 :(得分:1)
通过使用它,您可以对张量和模块应用相同的参数
def to(self, **kwargs):
module = super(VariationalGenerator, self).to(**kwargs)
module._train_noise = self._train_noise.to(**kwargs)
module._eval_noise = self._eval_noise.to(**kwargs)
return module
答案 2 :(得分:0)
使用此:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
现在同时使用模型和每个张量
net.to(device)
input = input.to(device)
答案 3 :(得分:0)
经过反复试验,我发现了两种方法:
self._train_noise = torch.randn(batch_size, embedding_size)
替换为self.register_buffer('_train_noise', torch.randn(batch_size, embedding_size)
,噪声张量将作为缓冲区添加到模块中。这也使net.to(device)
对其产生影响。此外,张量现在是state_dict的一部分。覆盖net.to(device)
:使用此选项,噪声不会进入state_dict。
def to(device):
new_self = super(VariationalGenerator, self).to(device)
new_self._train_noise = new_self._train_noise.to(device)
new_self._eval_noise = new_self._eval_noise.to(device)
return new_self