如何使用迁移学习模型进行剪枝?

时间:2021-06-25 09:28:50

标签: tensorflow keras tensorflow-lite transfer-learning pruning

本质上,我想对我的迁移学习模型进行修剪。

我使用 efficientnetb0 对微生物进行分类。

import tensorflow_model_optimization as tfmot

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

# Compute end step to finish pruning after 2 epochs.
batch_size = 32
epochs = 25

end_step = np.ceil(len(training_set) / batch_size).astype(np.int32) * epochs

# Define model for pruning.
pruning_params = {
      'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(
                                        initial_sparsity = 0.40,                                                                 
                                        final_sparsity = 0.90,                                                                   
                                        begin_step = 0,                                                                
                                        end_step = end_step
                                        )
                  }

model_for_pruning = prune_low_magnitude(
                         efficientnetb0_transfer_model, **pruning_params)

# `prune_low_magnitude` requires a recompile.
efficientnetb0_transfer_model_for_pruning.compile(optimizer=optim,
              loss='categorical_crossentropy',
              metrics=['accuracy'])

efficientnetb0_transfer_model_for_pruning.summary()

但是,我收到以下错误:

ValueError: Please initialize `Prune` with a supported layer. Layers should either be supported by the PruneRegistry (built-in keras layers) or should be a `PrunableLayer` instance, or should has a customer defined `get_prunable_weights` method. You passed: <class 'tensorflow.python.keras.layers.preprocessing.image_preprocessing.Rescaling'>

我可能做错了什么?

1 个答案:

答案 0 :(得分:1)

您正在点击这个 error

修剪 API 不够灵活。它目前期望模型中的所有层都是可修剪的(逻辑 here)。理想情况下,它应该能够跳过像图像重新缩放这样的层。您可以提交 github 问题,我们将进行修复。谢谢!