TensorFlow 2 quantization-aware training (QAT) with tf.GradientTape

Posted: 2021-03-31 00:12:17

Tags: tensorflow keras quantization

Can anyone point me to references on how to do quantization-aware training (QAT) with tf.GradientTape in TensorFlow 2?

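I have only seen this done with the tf.keras API. I don't use tf.keras for training; I always write custom training loops with tf.GradientTape so that I have more control over the training process. I now need to quantize a model, but the only references I can find show how to do it with the tf.keras API.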

1 answer:

Answer 0 (score: 1)

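The official example here shows QAT training with model.fit. Below is a demonstration of quantization-aware training with tf.GradientTape(). For a complete reference, though, let's do both here.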


Base model training. This comes directly from the official doc; see there for more details.

import os
import tensorflow as tf
from tensorflow import keras
import tensorflow_model_optimization as tfmot

# Load MNIST dataset
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()

# Normalize the input image so that each pixel value is between 0 to 1.
train_images = train_images / 255.0
test_images = test_images / 255.0

# Define the model architecture.
model = keras.Sequential([
  keras.layers.InputLayer(input_shape=(28, 28)),
  keras.layers.Reshape(target_shape=(28, 28, 1)),
  keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
  keras.layers.MaxPooling2D(pool_size=(2, 2)),
  keras.layers.Flatten(),
  keras.layers.Dense(10)
])

# Train the digit classification model
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
model.fit(
  train_images,
  train_labels,
  epochs=1,
  validation_split=0.1,
)
10ms/step - loss: 0.5411 - accuracy: 0.8507 - val_loss: 0.1142 - val_accuracy: 0.9705
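Since the question is specifically about tf.GradientTape, the base model's model.fit call above can also be replaced by a custom loop. The following is a minimal sketch under a few assumptions: random arrays stand in for the normalized MNIST data so that it runs offline, and the @tf.function-compiled train_step is a hypothetical helper, not part of the official example.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Assumption: random arrays stand in for the normalized MNIST
# train_images / train_labels so the sketch runs offline.
rng = np.random.default_rng(0)
train_images = rng.random((256, 28, 28), dtype=np.float32)
train_labels = rng.integers(0, 10, size=(256,))

# Same architecture as the base model above.
model = keras.Sequential([
    keras.layers.InputLayer(input_shape=(28, 28)),
    keras.layers.Reshape(target_shape=(28, 28, 1)),
    keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    keras.layers.MaxPooling2D(pool_size=(2, 2)),
    keras.layers.Flatten(),
    keras.layers.Dense(10)
])

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.Adam()

train_dataset = (tf.data.Dataset.from_tensor_slices((train_images, train_labels))
                 .shuffle(256)
                 .batch(32))

@tf.function  # hypothetical helper; traces one training step into a graph
def train_step(x, y):
    with tf.GradientTape() as tape:
        logits = model(x, training=True)
        loss = loss_fn(y, logits)
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss

for epoch in range(1):
    for x, y in train_dataset:
        loss = train_step(x, y)
    print(f'epoch {epoch}: last batch loss = {float(loss):.4f}')
```

The trained `model` can then be passed to quantize_model exactly as in the next section.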

QAT .fit

Now, perform QAT on the base model.

# -----------------------
# ------------- Quantization Aware Training -------------
import tensorflow_model_optimization as tfmot

quantize_model = tfmot.quantization.keras.quantize_model
# q_aware stands for quantization aware.
q_aware_model = quantize_model(model)

# `quantize_model` requires a recompile.
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])

q_aware_model.summary()
# Fine-tune with QAT on a 1000-example subset of the training data.
train_images_subset = train_images[0:1000]
train_labels_subset = train_labels[0:1000]
q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)


356ms/step - loss: 0.1431 - accuracy: 0.9629 - val_loss: 0.1626 - val_accuracy: 0.9500
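What quantize_model adds is fake quantization: during training, weights and activations are snapped to 8-bit levels in the forward pass and converted back to float, while gradients pass straight through the rounding (the straight-through estimator). That is why an ordinary tf.GradientTape loop can train the quantized model. The toy functions below (hypothetical names fake_quant and fake_quant_grad) only illustrate the idea in plain Python, with a single per-tensor range and no zero-point nudging; they are not how tfmot implements it.

```python
# Hypothetical helpers: a conceptual sketch of fake quantization,
# not tfmot's actual implementation.

def fake_quant(x: float, lo: float, hi: float, num_bits: int = 8) -> float:
    """Clamp x to [lo, hi], snap it to one of 2**num_bits levels, return a float."""
    levels = 2 ** num_bits - 1          # number of steps between lo and hi
    scale = (hi - lo) / levels          # width of one quantization step
    clamped = min(max(x, lo), hi)       # values outside the range saturate
    q = round((clamped - lo) / scale)   # integer code in [0, levels]
    return lo + q * scale               # dequantize back to float


def fake_quant_grad(x: float, lo: float, hi: float) -> float:
    """Straight-through estimator: d(fake_quant)/dx treated as 1 inside the range, 0 outside."""
    return 1.0 if lo <= x <= hi else 0.0


if __name__ == '__main__':
    for value in (-1.5, -0.3, 0.42, 2.0):
        print(value, fake_quant(value, -1.0, 1.0), fake_quant_grad(value, -1.0, 1.0))
```

The forward value differs from the input by at most half a step inside the range, so the network learns weights that tolerate this rounding error.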

Check performance

_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)

Baseline test accuracy: 0.9660999774932861
Quant test accuracy: 0.9660000205039978

QAT tf.GradientTape()

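This is the QAT training part for the base model, written with tf.GradientTape(). Note that the base model itself can also be trained with a custom loop in the same way.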

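batch_size = 500

# Same 1000-example subset as in the .fit example above.
# Note: q_aware_model was already fine-tuned with .fit above,
# so this loop continues training from those weights.
train_dataset = tf.data.Dataset.from_tensor_slices((train_images_subset,
                                                     train_labels_subset))
train_dataset = train_dataset.batch(batch_size=batch_size,
                                    drop_remainder=False)

loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = tf.keras.optimizers.Adam()

for epoch in range(1):
    for x, y in train_dataset:
        with tf.GradientTape() as tape:
            # Forward pass through the quantization-aware model
            preds = q_aware_model(x, training=True)
            loss = loss_fn(y, preds)
        # Gradients flow through the fake-quant ops (straight-through estimator)
        grads = tape.gradient(loss, q_aware_model.trainable_variables)
        optimizer.apply_gradients(zip(grads, q_aware_model.trainable_variables))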
        
_, baseline_model_accuracy = model.evaluate(
    test_images, test_labels, verbose=0)

_, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)

print('Baseline test accuracy:', baseline_model_accuracy)
print('Quant test accuracy:', q_aware_model_accuracy)
Baseline test accuracy: 0.9660999774932861
Quant test accuracy: 0.9645000100135803