What is wrong with my replication of TensorFlow's low-precision matrix multiplication?

Date: 2018-06-09 11:26:59

Tags: c++ tensorflow

I have been going through Google's gemmlowp code and documentation to try to understand how they quantize neural-network weights and still perform operations such as matrix multiplication. My understanding is that they subtract a zero-point offset from each of the matrices they want to multiply and then multiply them. At the end they add a result offset, round, and multiply by a fixed-point integer, which is supposed to rescale the values back into the [0, 255] range. In my case the final result is almost never in that range, so clamping just produces a matrix containing nothing but 0s and 255s. The quantization of the matrices being multiplied appears to be correct, which leads me to believe that the rescaling of the result itself is what I am getting wrong.

These are the references I am working from:

This shows how the scale values are obtained: https://github.com/google/gemmlowp/blob/master/doc/quantization_example.cc

This contains the basic matrix multiplication code: https://github.com/google/gemmlowp/blob/master/test/test.cc#L35
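
To restate the arithmetic I believe I am reproducing (this is my own summary of the gemmlowp documentation, so the misunderstanding may already be in here; lhs_scale, lhs_zero_point, etc. are the quantization parameters computed in the code below):

    real_result(i, j) = sum_k  lhs_scale * (uint8_lhs(i, k) - lhs_zero_point)
                             * rhs_scale * (uint8_rhs(k, j) - rhs_zero_point)

    uint8_result(i, j) = result_zero_point
                       + (lhs_scale * rhs_scale / result_scale)
                         * sum_k (uint8_lhs(i, k) - lhs_zero_point)
                                * (uint8_rhs(k, j) - rhs_zero_point)

The combined factor lhs_scale * rhs_scale / result_scale is the real_multiplier that gets turned into a fixed-point multiplier plus a right shift in the code.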

#include <cstdio>
#include <cstdint>   // std::uint8_t, std::int32_t, std::int64_t
#include <cstdlib>   // rand, srand, RAND_MAX
#include <ctime>     // time
#include <utility>
#include <limits>
#include <cctype>
#include <iostream>
#include <random>
#include <iomanip>
#include <cassert>
#include <cmath>
#include <algorithm> // std::min, std::max

template<class dtype>
class Mat{
public:
    const int rows, cols;
    dtype *array;

    Mat(int rows, int cols) : rows(rows), cols(cols){
        // Value-initialize so the accumulation in operator* below starts
        // from zero rather than from uninitialized memory.
        array = new dtype[cols * rows]();
    }

    ~Mat(){
        delete[] array;
    }

    void randInit()
    {
        // Seed only once: re-seeding with time(NULL) on every call would give
        // matrices filled in the same second identical leading values.
        static bool seeded = false;
        if(!seeded){
            srand(time(NULL));
            seeded = true;
        }

        for(int i = 0; i < rows; ++i)
            for(int j = 0; j < cols; ++j){
                float r = static_cast<float>(rand()) /
                    static_cast<float>(RAND_MAX);
                (*this)(i, j) = (dtype) (r * 2 - 1);  // uniform in [-1, 1]
            }
    }

    void shift_right(int shift_factor) const{
        for(int i = 0; i < rows * cols; i++)
            array[i] >>= shift_factor;
    }

    dtype& operator[](int idx) const{
        return array[idx];
    }

    dtype& operator()(int r, int c) const{
        return array[cols * r + c];
    }


    void mul(const dtype scalar){
        for(int size = rows * cols - 1; size >= 0; --size)
            array[size] *= scalar;
    }

    void add(const dtype scalar){
        for(int size = rows * cols - 1; size >= 0; --size)
            array[size] += scalar;
    }

    void sub(const dtype scalar){
        for(int size = rows * cols - 1; size >= 0; --size)
            array[size] -= scalar;
    }

    Mat operator*(const Mat& other){
        if(cols != other.rows){
            throw "Incorrect dimensions";
        }
        Mat result(rows, other.cols);

        for(int i = 0; i < rows; ++i){
            for(int j = 0; j < other.cols; ++j){
                for(int k = 0; k < cols; ++k){
                    result(i, j) += (*this)(i, k) * other(k, j);
                }
            }
        }
        return result;
    }

    std::pair<dtype, dtype> findMinMax(){
        // Note: lowest(), not min() -- for floating-point types min() is the
        // smallest positive value, which would break the max search.
        std::pair <dtype, dtype> result(
            std::numeric_limits<dtype>::max(), 
            std::numeric_limits<dtype>::lowest()
        );

        for(int i = rows * cols - 1; i >= 0; --i)
        {
            if(result.first > array[i])
                result.first = array[i];

            if(result.second < array[i])
                result.second = array[i];
        }

        return result;
    }

    template<typename dtype2>
    Mat& operator=(const Mat<dtype2>& x){

        for(int i = rows * cols - 1; i >= 0; --i)
            array[i] = static_cast<dtype>(x[i]);

        return *this;
    }

    template<typename dtype2>
    Mat(const Mat<dtype2> &p2) : rows(p2.rows), cols(p2.cols)
    {   
        array = new dtype[rows * cols];
        *(this) = p2;
    }


};

template<class dtype>
std::ostream &operator<<(std::ostream &os, const Mat<dtype>& mat)
{ 
    for(int i = 0; i < mat.rows; ++i){
        for(int j = 0; j < mat.cols; ++j){
            os <<std::setw(8)<<std::setprecision(3)<< mat(i, j) <<'\t';
        }
        os << std::endl;
    }
    os<<std::endl;
    return os;
}

template<>
std::ostream &operator<<(std::ostream &os, const Mat<std::uint8_t>& mat)
{ 
    for(int i = 0; i < mat.rows; ++i){
        for(int j = 0; j < mat.cols; ++j){
            os <<std::setw(8)<<std::setprecision(3)<< int(mat(i, j)) <<'\t';
        }
        os << std::endl;
    }
    os<<std::endl;
    return os;
}

struct OutputStageQuantizeDownInt32ToUint8Scale
{
    std::int32_t result_offset;
    std::int32_t result_mult_int;
    std::int32_t result_shift;
};
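
// Named after gemmlowp's output stage of the same name; here it is only a
// plain holder for the three quantize-down parameters used in main().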

struct QuantizationParams {
  float scale;
  std::uint8_t zero_point;
};

void QuantizeMultiplierSmallerThanOne(float real_multiplier,
                                      std::int32_t* quantized_multiplier,
                                      int* right_shift) {
  assert(real_multiplier > 0.f);
  assert(real_multiplier < 1.f);
  int s = 0;
  // We want to bring the real multiplier into the interval [1/2, 1).
  // We can do so by multiplying it by two, and recording how many times
  // we multiplied by two so that we can compensate that by a right
  // shift by the same amount.
  while (real_multiplier < 0.5f) {
    real_multiplier *= 2.0f;
    s++;
  }
  // Now that the real multiplier is in [1/2, 1), we convert it
  // into a fixed-point number.
  std::int64_t q =
      static_cast<std::int64_t>(std::round(real_multiplier * (1ll << 31)));
  assert(q <= (1ll << 31));
  // Handle the special case when the real multiplier was so close to 1
  // that its fixed-point approximation was undistinguishable from 1.
  // We handle this by dividing it by two, and remembering to decrement
  // the right shift amount.
  if (q == (1ll << 31)) {
    q /= 2;
    s--;
  }
  assert(s >= 0);
  assert(q <= std::numeric_limits<std::int32_t>::max());
  *quantized_multiplier = static_cast<std::int32_t>(q);
  *right_shift = s;
}
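
// Worked example (my own numbers, not taken from the gemmlowp sources): for
// real_multiplier = 0.3 the loop doubles it once to 0.6, so s = 1, and
// q = round(0.6 * 2^31) = 1288490189. As far as I can tell from gemmlowp's
// fixed-point code, q is then applied with a rounding "high mul" that keeps
// the upper 32 bits of the 64-bit product (absorbing the 2^31 factor), and
// only the remaining s bits are applied as right_shift.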

QuantizationParams ChooseQuantizationParams(float min, float max) {
    min = std::min(min, 0.f);
    max = std::max(max, 0.f);

    // the min and max quantized values, as floating-point values
    const float qmin = 0;
    const float qmax = 255;

    // First determine the scale.
    const double scale = (max - min) / (qmax - qmin);

    const double initial_zero_point = qmin - min / scale;

    std::uint8_t nudged_zero_point = 0;
    if (initial_zero_point < qmin) {
        nudged_zero_point = qmin;
    } else if (initial_zero_point > qmax) {
        nudged_zero_point = qmax;
    } else {
        nudged_zero_point =
            static_cast<std::uint8_t>(std::round(initial_zero_point));
    }

    QuantizationParams result;
    result.scale = scale;
    result.zero_point = nudged_zero_point;
    return result;
}
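
// Worked example (my own numbers): with min = -1 and max = 1 this gives
// scale = 2/255, about 0.00784, and initial_zero_point = 0 - (-1)/scale = 127.5,
// which is nudged to zero_point = 128.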

void Quantize(const QuantizationParams& qparams, const Mat<float>& src,
              Mat<std::uint8_t>& dst) {

  const size_t size = src.cols * src.rows;
  assert(size == dst.cols * dst.rows);

  for (std::size_t i = 0; i < size; i++) {
    const float real_val = src[i];
    const float transformed_val = qparams.zero_point + real_val / qparams.scale;
    const float clamped_val = std::max(0.f, std::min(255.f, transformed_val));
    dst[i] = static_cast<std::uint8_t>(std::round(clamped_val));
  }
}

void Dequantize(const QuantizationParams& qparams,
                const Mat<std::uint8_t>& src, 
                Mat<float>& dst) {
    const size_t size = src.cols * src.rows;
    assert(size == dst.cols * dst.rows);

    for (std::size_t i = 0; i < size; i++) {
        const std::uint8_t quantized_val = src[i];
        dst[i] = qparams.scale * (quantized_val - qparams.zero_point);
    }
}
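
// Worked example (my own numbers): with scale = 2/255 and zero_point = 128,
// a real value of 0.5 quantizes to round(128 + 0.5 / (2/255)) = round(191.75)
// = 192, and dequantizing 192 gives (2/255) * (192 - 128), about 0.502, so the
// round trip is accurate to within half a quantization step.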

int main()
{
    Mat<float> lhs(2, 4);   lhs.randInit();
    Mat<float> rhs(4, 3);   rhs.randInit();
    Mat<float> reference_result = lhs * rhs;

    Mat<std::uint8_t> uint8_lhs(2, 4);
    Mat<std::uint8_t> uint8_rhs(4, 3);

    auto lhs_min_max = lhs.findMinMax();
    auto rhs_min_max = rhs.findMinMax();
    auto result_min_max = reference_result.findMinMax();

    const auto lhs_qparams = ChooseQuantizationParams(
        lhs_min_max.first, lhs_min_max.second);
    const auto rhs_qparams = ChooseQuantizationParams(
        rhs_min_max.first, rhs_min_max.second);
    const auto result_qparams = ChooseQuantizationParams(
        result_min_max.first, result_min_max.second);

    Quantize(lhs_qparams, lhs, uint8_lhs);
    Quantize(rhs_qparams, rhs, uint8_rhs);


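    // Offsets follow the gemmlowp convention from quantization_example.cc:
    // the operand offsets are the negated zero-points of lhs and rhs, while
    // the result offset is the result's zero-point.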
    const int lhs_offset = -lhs_qparams.zero_point;
    const int rhs_offset = -rhs_qparams.zero_point;
    const int result_offset = result_qparams.zero_point;

    std::cout<<"RESULT_OFFSET\t"<<result_offset<<std::endl;

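    // Combined rescaling factor: the int32 accumulators are in units of
    // lhs_scale * rhs_scale, and dividing by result_scale brings them onto the
    // result's quantized scale. This is the value encoded below as a
    // fixed-point multiplier plus a right shift.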
    const float real_multiplier =
        lhs_qparams.scale * rhs_qparams.scale / result_qparams.scale;
    std::int32_t quantized_multiplier;
    int right_shift;
    QuantizeMultiplierSmallerThanOne(real_multiplier, &quantized_multiplier,
                                    &right_shift);


    std::cout << "End of OFFLINE QUANTIZATION CODE.\n" << std::endl;

    std::cout << "The below is ON-DEVICE RUNTIME QUANTIZED CODE. "
                << "This is the part that is performance-critical and may only "
                << "use quantized arithmetic.\n"
                << std::endl;


    OutputStageQuantizeDownInt32ToUint8Scale quantize_down_stage;
    quantize_down_stage.result_offset = result_offset;
    quantize_down_stage.result_mult_int = quantized_multiplier;
    quantize_down_stage.result_shift = right_shift;

    std::cout<<"QUANTIZED MULTIPLIER\t"<<quantized_multiplier<<std::endl;

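    // Widen the uint8 operands to int so the zero-point subtraction below can
    // produce negative values without wrapping.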
    Mat<int> int_lhs = uint8_lhs;
    Mat<int> int_rhs = uint8_rhs;

    int_lhs.add(lhs_offset);
    int_rhs.add(rhs_offset);

    Mat<int> int_result = int_lhs * int_rhs;
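
    // Quantize-down stage, following OutputStageQuantizeDownInt32ToUint8Scale
    // as I read it: add the result offset, multiply by the fixed-point
    // multiplier, add 2^(result_shift - 1) for round-to-nearest, then shift
    // right. (This is the step I suspect is wrong.)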
    int_result.add(result_offset);
    int_result.mul(quantize_down_stage.result_mult_int);
    int_result.add(quantize_down_stage.result_shift < 1 ?
        0 : (1 << (quantize_down_stage.result_shift - 1)));
    int_result.shift_right(quantize_down_stage.result_shift);


    std::cout<<"INT32 RESULT"<<std::endl;
    std::cout<<int_result<<std::endl;

    return 0;
}

0 Answers:

No answers yet.