使用向量的c ++ Karatsuba乘法

时间:2017-08-13 11:39:40

标签: c++ vector karatsuba

所以我一直在努力为Karatsuba乘法算法编写一个算法,并且我一直在尝试使用向量作为我的数据结构来处理将被输入的真正长数...

我的程序可以很小的数字,但它确实与更大的数字斗争,我得到核心转储(Seg Fault)。当左侧数字小于右侧时,它也会输出奇怪的结果。

有什么想法吗?这是代码。

#include <iostream>
#include <string>
#include <vector>

#define max(a,b) ((a) > (b) ? (a) : (b))

using namespace std;

vector<int> add(vector<int> lhs, vector<int> rhs) {
    int length = max(lhs.size(), rhs.size());
    int carry = 0;
    int sum_col;
    vector<int> result;

    while(lhs.size() < length) {
        lhs.insert(lhs.begin(), 0);
    }
    while(rhs.size() < length) {
        rhs.insert(rhs.begin(), 0);
    }

    for(int i = length-1; i >= 0; i--) {
        sum_col = lhs[i] + rhs[i] + carry;
        carry = sum_col/10;
        result.insert(result.begin(), (sum_col%10));
    }
    if(carry) {
        result.insert(result.begin(), carry);
    }
    int x = 0;
    while(result[x] == 0) {
        x++;
    }
    result.erase(result.begin(), result.begin()+x);
    return result;
}

vector<int> subtract(vector<int> lhs, vector<int> rhs) {
    int length = max(lhs.size(), rhs.size());
    int diff;
    vector<int> result;

    while(lhs.size() < length) {
        lhs.insert(lhs.begin(), 0);
    }
    while(rhs.size() < length) {
        rhs.insert(rhs.begin(), 0);
    }

    for(int i = length-1; i >= 0; i--) {
        diff = lhs[i] - rhs[i];
        if(diff >= 0) {
            result.insert(result.begin(), diff);
        } else {
            int j = i - 1;
            while(j >= 0) {
                lhs[j] = (lhs[j] - 1) % 10;
                if(lhs[j] != 9) {
                    break;
                } else {
                    j--;
                }
            }
            result.insert(result.begin(), diff+10);
        }
    }
    int x = 0;
    while(result[x] == 0) {
        x++;
    }
    result.erase(result.begin(), result.begin()+x);
    return result;
}

vector<int> multiply(vector<int> lhs, vector<int> rhs) {
    int length = max(lhs.size(), rhs.size());
    vector<int> result;

    while(lhs.size() < length) {
        lhs.insert(lhs.begin(), 0);
    }
    while(rhs.size() < length) {
        rhs.insert(rhs.begin(), 0);
    }

    if(length == 1) {
        int res = lhs[0]*rhs[0];
        if(res >= 10) {
            result.push_back(res/10);
            result.push_back(res%10);
            return result;
        } else {
            result.push_back(res);
            return result;
        }
    }

    vector<int>::const_iterator first0 = lhs.begin();
    vector<int>::const_iterator last0 = lhs.begin() + (length/2);
    vector<int> lhs0(first0, last0);
    vector<int>::const_iterator first1 = lhs.begin() + (length/2);
    vector<int>::const_iterator last1 = lhs.begin() + ((length/2) + (length-length/2));
    vector<int> lhs1(first1, last1);
    vector<int>::const_iterator first2 = rhs.begin();
    vector<int>::const_iterator last2 = rhs.begin() + (length/2);
    vector<int> rhs0(first2, last2);
    vector<int>::const_iterator first3 = rhs.begin() + (length/2);
    vector<int>::const_iterator last3 = rhs.begin() + ((length/2) + (length-length/2));
    vector<int> rhs1(first3, last3);

    vector<int> p0 = multiply(lhs0, rhs0);
    vector<int> p1 = multiply(lhs1,rhs1);
    vector<int> p2 = multiply(add(lhs0,lhs1),add(rhs0,rhs1));
    vector<int> p3 = subtract(p2,add(p0,p1));

    for(int i = 0; i < 2*(length-length/2); i++) {
        p0.push_back(0);
    }
    for(int i = 0; i < (length-length/2); i++) {
        p3.push_back(0);
    }

    result = add(add(p0,p1), p3);

    int x = 0;
    while(result[x] == 0) {
        x++;
    }
    result.erase(result.begin(), result.begin()+x);
    return result;
}

int main() {
    vector<int> lhs;
    vector<int> rhs;
    vector<int> v;

    lhs.push_back(2);
    lhs.push_back(5);
    lhs.push_back(2);
    lhs.push_back(5);
    lhs.push_back(2);
    lhs.push_back(5);
    lhs.push_back(2);
    lhs.push_back(5);

    rhs.push_back(1);
    rhs.push_back(5);
    rhs.push_back(1);
    rhs.push_back(5);
    rhs.push_back(1);
    rhs.push_back(5);
    rhs.push_back(1);


    v = multiply(lhs, rhs);

    for(size_t i = 0; i < v.size(); i++) {
        cout << v[i];
    }
    cout << endl;

    return 0;
    }

1 个答案:

答案 0 :(得分:1)

subtract存在几个问题。由于您无法表示负数,如果rhs大于lhs,您的借用逻辑将在lhs的数据开头之前访问。

如果结果为0,则在删除前导零时也可以超过result的末尾。

您的借用计算错误,因为-1 % 10将返回-1,而不是9,如果lhs[j]为0.更好的计算方法是添加9(比您的值小1)重新划分,lhs[j] = (lhs[j] + 9) % 10;

在不相关的注释中,您可以简化范围迭代计算。由于last0first1具有相同的值,因此您可以对{8}使用last0last1lhs.end()。这简化了lhs1

vector<int> lhs1(last0, lhs.end());

你可以摆脱first1last1rhs迭代器也是如此。