在Pytorch中为稀疏的CNN模型计算FLOPS

时间:2018-08-09 04:20:45

标签: sparse-matrix conv-neural-network pytorch flops

我尝试使用以下代码来确定CNN模型正向传递所需的浮点操作数。

对于已经非常稀疏(90%零)量化的类似模型,我希望所需的FLOPS数量会少得多,但是与原始模型相比,我得到的FLOPS数量相同。

如何为稀疏模型获取FLOPS,或者为什么值不发生变化?谢谢

def count_flops(model, input_image_size):

# flops count from each layer
counts = []

# loop over all model parts
for m in model.modules():
if isinstance(m, nn.Conv2d):
    def hook(module, input):
        factor = 2*module.in_channels*module.out_channels
        factor *= module.kernel_size[0]*module.kernel_size[1]
        factor //= module.stride[0]*module.stride[1]
        counts.append(
            factor*input[0].data.shape[2]*input[0].data.shape[3]
        )
    m.register_forward_pre_hook(hook)
elif isinstance(m, nn.Linear):
    counts += [
        2*m.in_features*m.out_features
    ]

noise_image = torch.rand(
1, 3, input_image_size, input_image_size
)
# one forward pass
_ = model(Variable(noise_image.cuda(), volatile=True))
return sum(counts)

0 个答案:

没有答案