我们是否可以修剪预先训练的模型?示例:MobileNetV2

时间:2020-05-25 12:31:55

标签: python tensorflow keras pruning

我正在尝试修剪一个经过预先训练的模型: MobileNetV2 ,但出现此错误。尝试过在线搜索,无法理解。我正在 Google Colab 上运行。

这些是我的进口商品。

import tensorflow as tf
import tensorflow_model_optimization as tfmot
import tensorflow_datasets as tfds
from tensorflow import keras

import os
import numpy as np
import matplotlib.pyplot as plt
import tempfile
import zipfile

这是我的代码。

model_1 = keras.Sequential([
    basemodel,
    keras.layers.GlobalAveragePooling2D(),
    keras.layers.Dense(1)                            
])

model_1.compile(optimizer='adam',
                loss=keras.losses.BinaryCrossentropy(from_logits=True),
                metrics=['accuracy'])

model_1.fit(train_batches,
            epochs=5,
            validation_data=valid_batches)

prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude

pruning_params = {
    'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,
                                                             final_sparsity=0.80,
                                                             begin_step=0,
                                                             end_step=end_step)
}


model_2 = prune_low_magnitude(model_1, **pruning_params)

model_2.compile(optmizer='adam',
                loss=keres.losses.BinaryCrossentropy(from_logits=True),
                metrics=['accuracy'])

这是我得到的错误。

---> 12 model_2 = prune_low_magnitude(model, **pruning_params)

ValueError: Please initialize `Prune` with a supported layer. Layers should either be a `PrunableLayer` instance, or should be supported by the PruneRegistry. You passed: <class 'tensorflow.python.keras.engine.training.Model'>

3 个答案:

答案 0 :(得分:0)

我相信您正在关注... ... const {StringMaxLen20Type, StringMaxLen25Type, StringMaxLen50Type, StringMaxLen255Type, StringMaxLen500Type } = require('StringMaxLenTypes.js');``` ... module.exports = { Query: {...}, Mutation: {...}, StringMaxLen20Type, StringMaxLen25Type, StringMaxLen50Type, StringMaxLen255Type, StringMaxLen500Type 并跳至Pruning in Keras Example部分,而未设置可修剪层。您必须重新实例化模型并设置要设置为Fine-tune pre-trained model with pruning的图层。请按照本指南获取有关如何设置可修剪图层的更多信息。

https://www.tensorflow.org/model_optimization/guide/pruning/comprehensive_guide.md

答案 1 :(得分:0)

我在以下方面遇到了同样的问题

  • tensorflow版本:2.2.0

只需将tensorflow的版本更新为2.3.0即可解决此问题,我认为Tensorflow在2.3.0中对此功能添加了支持。

答案 2 :(得分:0)

我发现的一件事是,我添加到模型中的实验性预处理引发了此错误。我在模型的开头使用了这个来帮助添加更多的训练样本,但是 keras 修剪代码不喜欢这样的子类模型。同样,代码也不喜欢像我那样对图像进行居中的实验性预处理。从模型中删除预处理为我解决了这个问题。

def classificationModel(trainImgs, testImgs):
  L2_lambda = 0.01
  data_augmentation = tf.keras.Sequential(
  [ layers.experimental.preprocessing.RandomFlip("horizontal", input_shape=IM_DIMS),
    layers.experimental.preprocessing.RandomRotation(0.1),
    layers.experimental.preprocessing.RandomZoom(0.1),])

  model = tf.keras.Sequential()
  model.add(data_augmentation)
  model.add(layers.experimental.preprocessing.Rescaling(1./255, input_shape=IM_DIMS))
...