我尝试使用以下代码来确定CNN模型正向传递所需的浮点操作数。
对于已经非常稀疏(90%零)量化的类似模型,我希望所需的FLOPS数量会少得多,但是与原始模型相比,我得到的FLOPS数量相同。
如何为稀疏模型获取FLOPS,或者为什么值不发生变化?谢谢
def count_flops(model, input_image_size):
# flops count from each layer
counts = []
# loop over all model parts
for m in model.modules():
if isinstance(m, nn.Conv2d):
def hook(module, input):
factor = 2*module.in_channels*module.out_channels
factor *= module.kernel_size[0]*module.kernel_size[1]
factor //= module.stride[0]*module.stride[1]
counts.append(
factor*input[0].data.shape[2]*input[0].data.shape[3]
)
m.register_forward_pre_hook(hook)
elif isinstance(m, nn.Linear):
counts += [
2*m.in_features*m.out_features
]
noise_image = torch.rand(
1, 3, input_image_size, input_image_size
)
# one forward pass
_ = model(Variable(noise_image.cuda(), volatile=True))
return sum(counts)