
时间:2018-06-09 11:26:59

标签: c++ tensorflow



这包含我们获取缩放值的方式 https://github.com/google/gemmlowp/blob/master/doc/quantization_example.cc

这包含基本矩阵乘法代码 https://github.com/google/gemmlowp/blob/master/test/test.cc#L35

#include <algorithm>
#include <cassert>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iomanip>
#include <iostream>
#include <limits>
#include <random>
#include <utility>

/// Minimal dense matrix with row-major storage (element (r, c) lives at
/// index cols * r + c of a single heap array).
///
/// The dimensions are fixed at construction. Element accessors are
/// const-qualified yet return mutable references, so a `const Mat&` still
/// allows element writes; callers in this file rely on that.
template<class dtype>
class Mat{
public:
    const int rows, cols;
    dtype *array;

    /// Allocates a rows x cols matrix. Elements are value-initialized (zero
    /// for arithmetic types) because operator* accumulates into a freshly
    /// constructed result with `+=`.
    Mat(int rows, int cols)
        : rows(rows), cols(cols), array(new dtype[rows * cols]()) {}

    /// Deep copy. Without it the implicit copy would share `array` between
    /// two objects and the destructor would free it twice.
    Mat(const Mat& other)
        : rows(other.rows), cols(other.cols),
          array(new dtype[other.rows * other.cols]) {
        for(int i = rows * cols - 1; i >= 0; --i)
            array[i] = other.array[i];
    }

    /// Converting copy from a matrix with a different element type; every
    /// element goes through static_cast<dtype>. Intentionally implicit so
    /// that `Mat<int> m = some_uint8_mat;` works.
    template<typename dtype2>
    Mat(const Mat<dtype2> &p2)
        : rows(p2.rows), cols(p2.cols), array(new dtype[p2.rows * p2.cols]) {
        *(this) = p2;
    }

    ~Mat(){
        delete[] array;
    }

    /// Fills the matrix with values r * 2 - 1 where r = rand() / RAND_MAX,
    /// converted to dtype.
    void randInit()
    {
        for(int i = 0; i < rows; ++i)
            for(int j = 0; j < cols; ++j){
                float r = static_cast <float> (std::rand()) /
                    static_cast <float> (RAND_MAX);
                (*this)(i, j) = static_cast<dtype>(r * 2 - 1);
            }
    }

    /// Applies `>>= shift_factor` to every element (integral dtype only).
    void shift_right(int shift_factor) const{
        for(int i = 0; i < rows * cols; i++)
            array[i] >>= shift_factor;
    }

    /// Flat access by storage index; no bounds checking.
    dtype& operator[](int idx) const{
        return array[idx];
    }

    /// Access by (row, column); no bounds checking.
    dtype& operator()(int r, int c) const{
        return array[cols * r + c];
    }

    /// Multiplies every element by `scalar` in place.
    void mul(const dtype scalar){
        for(int size = rows * cols - 1; size >= 0; --size)
            array[size] *= scalar;
    }

    /// Adds `scalar` to every element in place.
    void add(const dtype scalar){
        for(int size = rows * cols - 1; size >= 0; --size)
            array[size] += scalar;
    }

    /// Subtracts `scalar` from every element in place.
    void sub(const dtype scalar){
        for(int size = rows * cols - 1; size >= 0; --size)
            array[size] -= scalar;
    }

    /// Plain triple-loop matrix product.
    /// @throws const char* ("Incorrect dimensions") when cols != other.rows.
    Mat operator*(const Mat& other) const{
        if(cols != other.rows){
            throw "Incorrect dimensions";
        }
        Mat result(rows, other.cols);   // starts zeroed, see constructor

        for(int i = 0; i < rows; ++i){
            for(int j = 0; j < other.cols; ++j){
                for(int k = 0; k < cols; ++k){
                    result(i, j) += (*this)(i, k) * other(k, j);
                }
            }
        }
        return result;
    }

    /// Returns {smallest element, largest element}. For an empty matrix the
    /// pair is {numeric_limits::max(), numeric_limits::lowest()}.
    std::pair<dtype, dtype> findMinMax() const{
        // lowest() rather than min(): for floating-point types min() is the
        // smallest positive value, which would be wrong for all-negative data.
        std::pair <dtype, dtype> result(
            std::numeric_limits<dtype>::max(),
            std::numeric_limits<dtype>::lowest());

        for(int i = rows * cols - 1; i >= 0; --i)
        {
            if(result.first > array[i])
                result.first = array[i];

            if(result.second < array[i])
                result.second = array[i];
        }

        return result;
    }

    /// Element-wise assignment between matrices of the same element type.
    /// Dimensions are const, so both sides must already have the same shape.
    Mat& operator=(const Mat& x){
        assert(rows == x.rows && cols == x.cols);

        for(int i = rows * cols - 1; i >= 0; --i)
            array[i] = x.array[i];

        return *this;
    }

    /// Element-wise converting assignment; both sides must have the same shape.
    template<typename dtype2>
    Mat& operator=(const Mat<dtype2>& x){
        assert(rows == x.rows && cols == x.cols);

        for(int i = rows * cols - 1; i >= 0; --i)
            array[i] = static_cast<dtype>(x[i]);

        return *this;
    }
};

template<class dtype>
std::ostream &operator<<(std::ostream &os, const Mat<dtype>& mat)
    for(int i = 0; i < mat.rows; ++i){
        for(int j = 0; j < mat.cols; ++j){
            os <<std::setw(8)<<std::setprecision(3)<< mat(i, j) <<'\t';
        os << std::endl;
    return os;

std::ostream &operator<<(std::ostream &os, const Mat<std::uint8_t>& mat)
    for(int i = 0; i < mat.rows; ++i){
        for(int j = 0; j < mat.cols; ++j){
            os <<std::setw(8)<<std::setprecision(3)<< int(mat(i, j)) <<'\t';
        os << std::endl;
    return os;

/// Parameter bundle for scaling an int32 accumulator down to uint8.
/// In this file main() fills it with the result zero point, the fixed-point
/// multiplier and the right-shift amount.
struct OutputStageQuantizeDownInt32ToUint8Scale
{
    std::int32_t result_offset = 0;
    std::int32_t result_mult_int = 0;
    std::int32_t result_shift = 0;
};

/// Affine quantization parameters as used by Quantize/Dequantize below:
/// real_value = scale * (quantized_value - zero_point).
struct QuantizationParams {
  float scale = 0.f;
  std::uint8_t zero_point = 0;
};

/// Splits a real multiplier in (0, 1) into a fixed-point multiplier and a
/// right-shift amount.
///
/// @param real_multiplier       value strictly between 0 and 1 (asserted).
/// @param quantized_multiplier  receives round(m * 2^31), where m is
///                              real_multiplier scaled by doublings into
///                              [1/2, 1).
/// @param right_shift           receives the number of doublings applied.
void QuantizeMultiplierSmallerThanOne(float real_multiplier,
                                      std::int32_t* quantized_multiplier,
                                      int* right_shift) {
  assert(real_multiplier > 0.f);
  assert(real_multiplier < 1.f);
  int s = 0;
  // We want to bring the real multiplier into the interval [1/2, 1).
  // We can do so by multiplying it by two, and recording how many times
  // we multiplied by two so that we can compensate that by a right
  // shift by the same amount.
  while (real_multiplier < 0.5f) {
    real_multiplier *= 2.0f;
    s++;
  }
  // Now that the real multiplier is in [1/2, 1), we convert it
  // into a fixed-point number.
  std::int64_t q =
      static_cast<std::int64_t>(std::round(real_multiplier * (1ll << 31)));
  assert(q <= (1ll << 31));
  // Handle the special case when the real multiplier was so close to 1
  // that its fixed-point approximation was indistinguishable from 1.
  // We handle this by dividing it by two, and remembering to decrement
  // the right shift amount.
  if (q == (1ll << 31)) {
    q /= 2;
    s--;
  }
  assert(s >= 0);
  assert(q <= std::numeric_limits<std::int32_t>::max());
  *quantized_multiplier = static_cast<std::int32_t>(q);
  *right_shift = s;
}

/// Picks a scale and an 8-bit zero point for the real range [min, max].
///
/// The range is first widened to include 0. The scale maps its width onto
/// the 256 quantized levels 0..255, and the zero point is the rounded and
/// clamped quantized value corresponding to real 0.
///
/// @param min  smallest real value to represent.
/// @param max  largest real value to represent.
/// @return the chosen scale and zero point. When the widened range is empty
///         (min == max == 0) the result is scale 1 and zero point 0, which
///         avoids dividing by a zero scale.
QuantizationParams ChooseQuantizationParams(float min, float max) {
    min = std::min(min, 0.f);
    max = std::max(max, 0.f);

    // the min and max quantized values, as floating-point values
    const float qmin = 0;
    const float qmax = 255;

    if (min == max) {
        // Only possible when both are exactly 0: the scale would be 0 and
        // min / scale below would be 0 / 0.
        QuantizationParams degenerate;
        degenerate.scale = 1.f;
        degenerate.zero_point = 0;
        return degenerate;
    }

    // First determine the scale.
    const double scale = (max - min) / (qmax - qmin);

    // Quantized value that real 0 maps to, before rounding and clamping.
    const double initial_zero_point = qmin - min / scale;

    std::uint8_t nudged_zero_point = 0;
    if (initial_zero_point < qmin) {
        nudged_zero_point = static_cast<std::uint8_t>(qmin);
    } else if (initial_zero_point > qmax) {
        nudged_zero_point = static_cast<std::uint8_t>(qmax);
    } else {
        nudged_zero_point =
            static_cast<std::uint8_t>(std::round(initial_zero_point));
    }

    QuantizationParams result;
    result.scale = static_cast<float>(scale);
    result.zero_point = nudged_zero_point;
    return result;
}

/// Quantizes every element of `src` into `dst`:
/// dst[i] = round(clamp(zero_point + src[i] / scale, 0, 255)).
///
/// @param qparams  scale and zero point to apply.
/// @param src      real-valued input.
/// @param dst      output; must hold the same number of elements as `src`
///                 (asserted).
void Quantize(const QuantizationParams& qparams, const Mat<float>& src,
              Mat<std::uint8_t>& dst) {

  const std::size_t size = static_cast<std::size_t>(src.cols) * src.rows;
  assert(size == static_cast<std::size_t>(dst.cols) * dst.rows);

  for (std::size_t i = 0; i < size; i++) {
    const float real_val = src[i];
    const float transformed_val = qparams.zero_point + real_val / qparams.scale;
    const float clamped_val = std::max(0.f, std::min(255.f, transformed_val));
    dst[i] = static_cast<std::uint8_t>(std::round(clamped_val));
  }
}

/// Maps every quantized element of `src` back to a real value in `dst`:
/// dst[i] = scale * (src[i] - zero_point).
///
/// @param qparams  scale and zero point to apply.
/// @param src      quantized input.
/// @param dst      output; must hold the same number of elements as `src`
///                 (asserted).
void Dequantize(const QuantizationParams& qparams,
                const Mat<std::uint8_t>& src, 
                Mat<float>& dst) {
    const std::size_t size = static_cast<std::size_t>(src.cols) * src.rows;
    assert(size == static_cast<std::size_t>(dst.cols) * dst.rows);

    for (std::size_t i = 0; i < size; i++) {
        const std::uint8_t quantized_val = src[i];
        // The subtraction happens in int, so values below the zero point
        // come out negative instead of wrapping.
        dst[i] = qparams.scale * (quantized_val - qparams.zero_point);
    }
}

int main()
    Mat<float> lhs(2, 4);   lhs.randInit();
    Mat<float> rhs(4, 3);   rhs.randInit();
    Mat<float> reference_result = lhs * rhs;

    Mat<std::uint8_t> uint8_lhs(2, 4);
    Mat<std::uint8_t> uint8_rhs(4, 3);

    auto lhs_min_max = lhs.findMinMax();
    auto rhs_min_max = rhs.findMinMax();
    auto result_min_max = reference_result.findMinMax();

    const auto lhs_qparams = ChooseQuantizationParams(
        lhs_min_max.first, lhs_min_max.second);
    const auto rhs_qparams = ChooseQuantizationParams(
        rhs_min_max.first, rhs_min_max.second);
    const auto result_qparams = ChooseQuantizationParams(
        result_min_max.first, result_min_max.second);

    Quantize(lhs_qparams, lhs, uint8_lhs);
    Quantize(rhs_qparams, rhs, uint8_rhs);

    const int lhs_offset = -lhs_qparams.zero_point;
    const int rhs_offset = -rhs_qparams.zero_point;
    const int result_offset = result_qparams.zero_point;


    const float real_multiplier =
        lhs_qparams.scale * rhs_qparams.scale / result_qparams.scale;
    std::int32_t quantized_multiplier;
    int right_shift;
    QuantizeMultiplierSmallerThanOne(real_multiplier, &quantized_multiplier,

    std::cout << "End of OFFLINE QUANTIZATION CODE.\n" << std::endl;

    std::cout << "The below is ON-DEVICE RUNTIME QUANTIZED CODE. "
                << "This is the part that is performance-critical and may only "
                << "use quantized arithmetic.\n"
                << std::endl;

    OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
    quantize_down_stage.result_offset = result_offset;
    quantize_down_stage.result_mult_int = quantized_multiplier;
    quantize_down_stage.result_shift = right_shift;

    std::cout<<"QUANTIZED MULTIPLIER\t"<<quantized_multiplier<<std::endl;

    Mat<int> int_lhs = uint8_lhs;
    Mat<int> int_rhs = uint8_rhs;


    Mat<int> int_result = int_lhs * int_rhs;
    int_result.add(quantize_down_stage.result_shift < 1 ?
        0 : (1 << (quantize_down_stage.result_shift - 1)));

    std::cout<<"INT32 RESULT"<<std::endl;

    return 0;

0 个答案:
