如何在 PyTorch conv2d 函数中批量使用组参数?

时间:2021-04-09 07:17:52

标签: pytorch

按照How to use groups parameter in PyTorch conv2d function中的问题

我可以知道输入批次大小是否为 4,对于每个批次,它都有独立的过滤器来进行转换,我将代码修改如下,

import torch
import torch.nn.functional as F

filters = torch.autograd.Variable(torch.randn(3,4,3,3))
inputs = torch.autograd.Variable(torch.randn(4,3,10,10))
out = F.conv2d(inputs, filters, padding=1, groups=3)

我还有一个错误 运行时错误:给定组 = 3,大小为 [3, 4, 3, 3] 的权重,预期输入 [4, 3, 10, 10] 有 12 个通道,但得到了 3 个通道 如何解决?

1 个答案:

答案 0 :(得分:0)

当您有 shape (3,4,3,3) 过滤器时,预计通道数为 12

这应该有效

import torch
import torch.nn.functional as F
inputs = torch.autograd.Variable(torch.randn(3,12,10,10))
filters = torch.autograd.Variable(torch.randn(3,4,3,3))
out = F.conv2d(inputs, filters, padding=1, groups=3)