
Basics of Multi-Precision Integers (Karatsuba Multiplication)

Posted at 2014-06-02

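Overview

The previous installment (the four arithmetic operations) covered the implementation of classical multiplication. To compute the product $h_{2n} = f_{n} g_{n}$ of two $n$-digit polynomials $f_{n}, g_{n} \in \text{R}[x]$, classical multiplication needs $\text{O}(n^2)$ multiplications, i.e. quadratic time. Here $x$ is the variable of the polynomials and $\text{R}$ is a commutative ring with identity. For multi-precision integers, $x$ plays the role of the radix.
The following expression expands the product $h_{2n} = f_{n} g_{n}$ after splitting each factor into an upper and a lower half. Four products appear, and when each of those four is split again recursively, each of them again needs four multiplications.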

$$f_n g_n = (A x^{n/2} + B)(C x^{n/2} + D) = A C x^{n} + (A D + B C) x^{n/2} + B D = h_{2n}$$

Here $A$ and $B$ are the upper and lower $n/2$ digits of $f_n$, and $C$ and $D$ those of $g_n$.

Karatsuba multiplication reduces this $\text{O}(n^2)$ to $\text{O}(n^{\log_2 3}) \approx \text{O}(n^{1.585})$, i.e. subquadratic time.
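To make the difference in growth concrete, here is a small sketch that counts single-limb multiplications for $n = 2^k$ limbs. It assumes that each level of the recursion replaces one product by three half-size products (the rewriting derived in the Algorithm section) and that the recursion goes all the way down to single limbs; the function names are hypothetical.

```cpp
#include <cstdint>

// Classical method: one single-limb multiplication per pair of limbs.
std::uint64_t classical_mults(std::uint64_t n){
    return n * n;
}

// Karatsuba, assuming n is a power of two and the recursion goes down to single
// limbs: M(1) = 1 and M(n) = 3 M(n/2), hence M(2^k) = 3^k = n^(log2 3).
std::uint64_t karatsuba_mults(std::uint64_t n){
    return n <= 1 ? 1 : 3 * karatsuba_mults(n / 2);
}
```

For $n = 2048 = 2^{11}$ this gives 4,194,304 versus 177,147 multiplications; each doubling of $n$ multiplies the first count by 4 and the second by 3.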

Algorithm

The product $h_{2n} = f_n g_n$ can be rewritten as follows.

$$f_n g_n = A C x^{n} + \left((A + B)(C + D) - A C - B D\right) x^{n/2} + B D = h_{2n}$$

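There are more additions and subtractions, but only three multiplications remain: $A C$, $B D$, and $(A + B)(C + D)$. If each of these three products is in turn computed recursively with three multiplications, the running time satisfies $T(n) = 3T(n/2) + \text{O}(n)$: at each of the $\log_2 n$ levels the number of subproducts triples while their size halves, so the total number of multiplications is $3^{\log_2 n} = n^{\log_2 3}$, i.e. $\text{O}(n^{\log_2 3})$. This is Karatsuba multiplication.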
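As a minimal check of the identity above, the following sketch performs a single Karatsuba step on 32-bit operands split into 16-bit halves ($A$, $B$ of `a` and $C$, $D$ of `b`, as in the formula), using three products instead of four. The name `karatsuba_step` is hypothetical; the article's `integer` class applies the same idea recursively to limb ranges, with a signed variant of the middle term.

```cpp
#include <cstdint>

// One Karatsuba step: a = A*2^16 + B, b = C*2^16 + D.
std::uint64_t karatsuba_step(std::uint32_t a, std::uint32_t b){
    const std::uint64_t A = a >> 16, B = a & 0xFFFF;
    const std::uint64_t C = b >> 16, D = b & 0xFFFF;
    const std::uint64_t ac = A * C;                         // multiplication 1
    const std::uint64_t bd = B * D;                         // multiplication 2
    const std::uint64_t mid = (A + B) * (C + D) - ac - bd;  // multiplication 3, equals AD + BC
    // every intermediate fits in 64 bits, and so does the final sum a * b
    return (ac << 32) + (mid << 16) + bd;
}
```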

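Benchmark

(Implementation details follow in the next sections.)
With 16-bit limbs, values of 128 × (step number) digits each were multiplied with classical multiplication (upper value in each step) and with Karatsuba multiplication (lower value), and the times in seconds were compared.
Environment: Microsoft Visual C++ 2013 on a Core i7-3770 at 3.40 GHz.
The code used for the comparison is as follows.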

#include "integer.hpp"

#include <iostream>
#include <boost/timer.hpp>

using integer = multi_precision::integer<>;

void kar_time(std::size_t num){
    integer a, b, c, d;
    for(std::size_t i = 0; i < num; ++i){
        a.data.push_back(i + 1);
        b.data.push_back(num - i);
    }
    {
        boost::timer t;
        c = integer::classic_mul(a, b);
        double u = t.elapsed();
        std::cout << u << std::endl;
    }
    {
        boost::timer t;
        d = integer::mul(a, b);
        double u = t.elapsed();
        std::cout << u << std::endl;
    }
    if(c == d){
        std::cout << "success..." << std::endl;
    }else{
        std::cout << "fail..." << std::endl;
    }
}

int main(){
    for(std::size_t i = 0; i < 16; ++i){
        std::cout << "step " << (i + 1) << "--------" << std::endl;
        kar_time((i + 1) * 128);
        std::cout << std::endl;
    }
    return 0;
}
step 1--------
0.002
0
success...

step 2--------
0.013
0.001
success...

step 3--------
0.038
0.002
success...

step 4--------
0.089
0.002
success...

step 5--------
0.172
0.005
success...

step 6--------
0.295
0.005
success...

step 7--------
0.474
0.005
success...

step 8--------
0.704
0.004
success...

step 9--------
1.002
0.01
success...

step 10--------
1.369
0.01
success...

step 11--------
1.824
0.012
success...

step 12--------
2.363
0.01
success...

step 13--------
2.995
0.013
success...

step 14--------
3.753
0.012
success...

step 15--------
4.608
0.013
success...

step 16--------
5.589
0.008
success...

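For the final multiplication of 128 × 16 = 2048 digits, classical multiplication takes more than 5.5 seconds, while Karatsuba multiplication takes only 0.008 seconds. Because the algorithm splits at powers of two, it wastes no work on operands of exactly $2^n$ digits and is especially fast there (compare steps 8 and 16 with their neighbors); but even against the 0.013 seconds of the preceding step 15, the classical time of step 16 is still close to 430 times larger. Two caveats apply. All times are multiples of 0.001 s, the timer's resolution, so the Karatsuba values of a few milliseconds are rough. And the classical times grow about 8 times per doubling of the size (0.089 s, 0.704 s, 5.589 s at steps 4, 8, 16), i.e. cubically rather than quadratically: the `unsigned_mul` used for these measurements keeps propagating each partial product's carry up to the top limb even after the carry has become zero. A genuinely quadratic classical multiplication would show a much smaller gap.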
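`boost::timer` from `<boost/timer.hpp>` is built on `std::clock()`, whose resolution and meaning (CPU time or wall time) depend on the platform. If finer, portable wall-clock timing is wanted, one option is a small helper around `std::chrono::steady_clock`; this is a sketch, and the name `seconds_elapsed` is hypothetical.

```cpp
#include <chrono>
#include <utility>

// Wall-clock duration of calling f(), in seconds. steady_clock is monotonic,
// so the measurement is not affected by adjustments of the system clock.
template<class F>
double seconds_elapsed(F &&f){
    const auto start = std::chrono::steady_clock::now();
    std::forward<F>(f)();
    const auto stop = std::chrono::steady_clock::now();
    return std::chrono::duration<double>(stop - start).count();
}
```

In `kar_time`, each timed block could then become `double u = seconds_elapsed([&]{ d = integer::mul(a, b); });`. For millisecond-scale operations, repeating the call several times inside the lambda and dividing gives steadier numbers.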

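Preparation

Representing operands as ranges

If the algorithm is written down naively, the number of temporary variables grows and copying data into them slows everything down. To avoid this, it is effective to run each step on ranges (pairs of iterators) over the parts of the operands that are not modified.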

struct kar_range{
    typename data_type::const_iterator first, last;
    std::size_t size;
};
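The split that `kar_rec_mul` performs on these ranges can be shown in isolation. The sketch below mirrors `kar_range` with 16-bit limbs (an assumption matching the default `UInt`) and a hypothetical helper `split_range` that yields the lower part of at most `n` limbs and the remaining upper part without copying any limbs.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using data_type = std::vector<std::uint16_t>; // least significant limb first

// Same shape as kar_range: a view [first, last) plus its length.
struct kar_range{
    data_type::const_iterator first, last;
    std::size_t size;
};

// Split r into a low part of at most n limbs and the (possibly empty) high part.
// Only iterators are adjusted; the limbs themselves are never copied.
void split_range(const kar_range &r, std::size_t n, kar_range &lo, kar_range &hi){
    const std::size_t k = r.size < n ? r.size : n;
    lo.first = r.first;     lo.last = r.first + k; lo.size = k;
    hi.first = r.first + k; hi.last = r.last;      hi.size = r.size - k;
}
```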

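Rounding up to a power of two

Karatsuba multiplication basically splits the operands into ranges at powers of two, so the operand size is rounded up to a power of two.
The function counts the binary digits of x (m) until x becomes 0, incrementing n whenever a 1 bit appears. If n is 1, every bit except the top one is 0, i.e. x is already a power of two, so it simply reconstructs and returns x as 1 << (m - 1). If n is 2 or more, at least one lower bit is 1, so it rounds up and returns 1 << m. For x = 0 it returns 1.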

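// Round x up to the next power of two (1 for x == 0).
// Takes std::size_t: the argument is a limb count, which can exceed the range of UInt.
static std::size_t ceil_pow2_single(std::size_t x){
    std::size_t n = 0, m = 0; // n: number of set bits, m: bit length
    while(x > 0){
        if((x & 1) > 0){
            ++n;
        }
        ++m;
        x >>= 1;
    }
    if(n == 1){
        return static_cast<std::size_t>(1) << (m - 1); // already a power of two
    }
    return static_cast<std::size_t>(1) << m;
}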
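A shorter way to get the same value, as a sketch assuming the result fits in `std::size_t`: keep doubling a power of two until it reaches x. Like `ceil_pow2_single`, it returns 1 for x = 0. The name `ceil_pow2` is hypothetical; C++20 also provides `std::bit_ceil` in `<bit>` for unsigned types.

```cpp
#include <cstddef>

// Smallest power of two that is >= x (1 for x == 0 and x == 1).
// Assumes the result is representable in std::size_t.
std::size_t ceil_pow2(std::size_t x){
    std::size_t p = 1;
    while(p < x){
        p <<= 1; // double until p reaches x
    }
    return p;
}
```

For example, the 1920 digits of step 15 in the benchmark are rounded up to 2048, so that step splits as if the operands had 2048 limbs.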

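Radix shift

Multiplies a multi-precision value by $\text{Radix}^n$. This corresponds to the left shift familiar from binary. Since data stores the least significant limb first, it is enough to insert n zero limbs at the front of data.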

void elemental_shift(std::size_t n){
    data.insert(data.begin(), n, 0);
}
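A standalone illustration of the radix shift, assuming 16-bit limbs stored least significant first as in `data`; `radix_shift` and `to_u64` are hypothetical helper names.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using limbs = std::vector<std::uint16_t>; // least significant limb first

// Multiply by (2^16)^n by inserting n zero limbs at the least-significant end.
void radix_shift(limbs &a, std::size_t n){
    a.insert(a.begin(), n, 0);
}

// Value of a (at most 4 limbs) as a 64-bit integer, for checking.
std::uint64_t to_u64(const limbs &a){
    std::uint64_t v = 0;
    for(std::size_t i = a.size(); i-- > 0; ){
        v = (v << 16) | a[i];
    }
    return v;
}
```

With `a = {0x5678, 0x1234}` (the value 0x12345678), `radix_shift(a, 1)` gives `{0, 0x5678, 0x1234}`, i.e. 0x123456780000.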

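Comparison function kar_less (range based)

Evaluates *this < [rhs_first, rhs_last) for the range given by rhs_first and rhs_last. Used by kar_sub (range based). The two sides may differ in length: the extra high limbs of the longer side are first checked for zeros, then the overlapping limbs are compared from the most significant one down.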

bool kar_less(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last) const{
    std::size_t lhs_n = data.size(), rhs_n = std::distance(rhs_first, rhs_last);
    typename data_type::const_iterator lhs_first = data.begin(), lhs_iter = lhs_first, rhs_iter = rhs_first;
    if(lhs_n < rhs_n){
        std::size_t n = rhs_n - lhs_n;
        for(std::size_t i = 0; i < n; ++i){
            if(*(rhs_iter + (rhs_n - i - 1)) > 0){
                return true;
            }
        }
        rhs_iter += lhs_n - 1;
        lhs_iter = lhs_first + (lhs_n - 1);
    }else if(lhs_n > rhs_n){
        std::size_t n = lhs_n - rhs_n;
        for(std::size_t i = 0; i < n; ++i){
            if(*(lhs_iter + (lhs_n - i - 1)) > 0){
                return false;
            }
        }
        // the extra high limbs of *this are all zero: compare the overlapping limbs
        lhs_iter += rhs_n - 1;
        rhs_iter += rhs_n - 1;
    }else{
        lhs_iter += lhs_n - 1;
        rhs_iter += rhs_n - 1;
    }
    for(; ; ){
        UInt l = *lhs_iter, r = *rhs_iter;
        if(l < r){ return true; }
        if(l > r){ return false; }
        if(lhs_iter == lhs_first){
            break;
        }
        --lhs_iter, --rhs_iter;
    }
    return false;
}
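For comparison, the same test can be sketched as a free function over two iterator ranges. Stripping the high zero limbs first removes the need for separate branches per length case. The names `range_less`, `limbs`, and `citer` are hypothetical.

```cpp
#include <cstdint>
#include <vector>

using limbs = std::vector<std::uint16_t>; // least significant limb first
using citer = limbs::const_iterator;

// true if [lf, ll) is numerically less than [rf, rl); high zero limbs are ignored.
bool range_less(citer lf, citer ll, citer rf, citer rl){
    while(ll != lf && *(ll - 1) == 0){ --ll; } // drop high zero limbs
    while(rl != rf && *(rl - 1) == 0){ --rl; }
    const auto ln = ll - lf, rn = rl - rf;
    if(ln != rn){ return ln < rn; }            // more significant limbs means larger
    while(ll != lf){                           // same length: first difference from the top decides
        --ll; --rl;
        if(*ll != *rl){ return *ll < *rl; }
    }
    return false;                              // equal
}
```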

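Subtraction kar_sub (range based)

Range-based subtraction: computes *this - [rhs_first, rhs_last). If *this is the smaller value, it subtracts the other way round and flips the sign.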

void kar_sub(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last){
    if(kar_less(rhs_first, rhs_last)){
        integer s(*this);
        data.assign(rhs_first, rhs_last);
        unsigned_sub(s);
        sign *=  - 1;
        normalize_data_size();
    }else{
        std::size_t n = std::distance(rhs_first, rhs_last);
        if(n > data.size()){
            unsigned_sub(rhs_first, rhs_first + data.size());
        }else{
            unsigned_sub(rhs_first, rhs_last);
        }
        normalize_data_size();
    }
}
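The sign handling of `kar_sub` (swap the operands when the left side is smaller, subtract the magnitudes, flip the sign) can be sketched on plain vectors as follows. `limbs`, `limbs_less`, `magnitude_sub`, and `signed_sub` are hypothetical names; the result follows the same sign/magnitude convention as `integer`, where zero has sign +1.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using limbs = std::vector<std::uint16_t>; // least significant limb first, base 2^16

// Numeric a < b; missing high limbs count as zero.
bool limbs_less(const limbs &a, const limbs &b){
    for(std::size_t i = std::max(a.size(), b.size()); i-- > 0; ){
        const std::uint16_t x = i < a.size() ? a[i] : 0, y = i < b.size() ? b[i] : 0;
        if(x != y){ return x < y; }
    }
    return false;
}

// a - b for a >= b, with borrows; high zero limbs of the result are removed.
limbs magnitude_sub(const limbs &a, const limbs &b){
    limbs r(a);
    std::uint32_t borrow = 0;
    for(std::size_t i = 0; i < r.size(); ++i){
        const std::uint32_t sub = (i < b.size() ? b[i] : 0u) + borrow;
        borrow = r[i] < sub ? 1u : 0u;
        r[i] = static_cast<std::uint16_t>(r[i] + (borrow << 16) - sub);
    }
    while(!r.empty() && r.back() == 0){ r.pop_back(); }
    return r;
}

// a - b as (sign, magnitude): if a < b, compute b - a and flip the sign.
std::pair<int, limbs> signed_sub(const limbs &a, const limbs &b){
    if(limbs_less(a, b)){
        return {-1, magnitude_sub(b, a)};
    }
    return {+1, magnitude_sub(a, b)};
}
```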

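Addition kar_add (range based)

Nothing special here: limb-by-limb addition, with the final carry propagated into the higher limbs of *this. (kar_rec_mul below actually uses add rather than kar_add.)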

void kar_add(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last){
    std::size_t rhs_size = std::distance(rhs_first, rhs_last);
    if(data.size() < rhs_size){
        data.resize(rhs_size);
    }
    DoubleUInt c = 0;
    std::size_t i = 0;
    for(typename data_type::const_iterator iter = rhs_first; iter != rhs_last; ++iter, ++i){
        DoubleUInt s = static_cast<DoubleUInt>(data[i]) + static_cast<DoubleUInt>(*iter) + c;
        data[i] = static_cast<UInt>(s & base_mask);
        c = s >> BitNum;
    }
    // propagate the remaining carry into the higher limbs of *this,
    // appending a limb only when the carry leaves the top
    for(; c > 0; ++i){
        if(i >= data.size()){ data.push_back(0); }
        DoubleUInt s = static_cast<DoubleUInt>(data[i]) + c;
        data[i] = static_cast<UInt>(s & base_mask);
        c = s >> BitNum;
    }
}
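The carry handling matters when *this is longer than the added range: the last carry must move into the next limb of *this instead of being appended after the top. A standalone sketch of that behavior, with hypothetical names `limbs` and `range_add`:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using limbs = std::vector<std::uint16_t>; // least significant limb first, base 2^16

// a += [first, last): add limb by limb, then keep propagating the carry into the
// higher limbs of a, growing a by one limb only if the carry leaves its top.
void range_add(limbs &a, limbs::const_iterator first, limbs::const_iterator last){
    const std::size_t n = static_cast<std::size_t>(last - first);
    if(a.size() < n){ a.resize(n, 0); }
    std::uint32_t carry = 0;
    std::size_t i = 0;
    for(; first != last; ++first, ++i){
        const std::uint32_t s = std::uint32_t(a[i]) + *first + carry;
        a[i] = static_cast<std::uint16_t>(s & 0xFFFF);
        carry = s >> 16;
    }
    for(; carry > 0; ++i){
        if(i == a.size()){ a.push_back(0); }
        const std::uint32_t s = std::uint32_t(a[i]) + carry;
        a[i] = static_cast<std::uint16_t>(s & 0xFFFF);
        carry = s >> 16;
    }
}
```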

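Primitive multiplication (range based)

Multiplication of values too small for Karatsuba multiplication to split any further (the base case of the recursion). This is also just ordinary multiplication, only on ranges.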

void kar_mul_range(const kar_range &lhs, const kar_range &rhs){
    data.resize(lhs.size + rhs.size - 1);
    if(data.empty()){
        normalize_data_size();
        return;
    }
    std::size_t n = 0;
    for(typename data_type::const_iterator rhs_iter = rhs.first; rhs_iter != rhs.last; ++rhs_iter, ++n){
        UInt rhs_value = *rhs_iter, lhs_value;
        std::size_t m = 0;
        for(typename data_type::const_iterator lhs_iter = lhs.first; lhs_iter != lhs.last; ++lhs_iter, ++m){
            lhs_value = *lhs_iter;
            DoubleUInt a = static_cast<DoubleUInt>(rhs_value) * static_cast<DoubleUInt>(lhs_value);
            unsigned_single_add(static_cast<UInt>(a & base_mask), n + m);
            UInt temp = static_cast<UInt>((a >> BitNum) & base_mask);
            if(temp > 0){
                unsigned_single_add(temp, n + m + 1);
            }
        }
    }
}
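For reference, here is a standalone sketch of schoolbook multiplication over limb vectors in which each row's carry is settled within that row, so the work is exactly one limb product per pair of limbs. Compare `unsigned_mul` in the full source, whose inner loop walks each carry to the top limb. The names are hypothetical and 16-bit limbs are assumed.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

using limbs = std::vector<std::uint16_t>; // least significant limb first, base 2^16

// Schoolbook product: x.size() * y.size() limb multiplications in total.
limbs schoolbook_mul(const limbs &x, const limbs &y){
    if(x.empty() || y.empty()){ return {}; }
    limbs r(x.size() + y.size(), 0);
    for(std::size_t i = 0; i < x.size(); ++i){
        std::uint32_t carry = 0;
        for(std::size_t j = 0; j < y.size(); ++j){
            // at most (2^16-1)^2 + 2*(2^16-1) = 2^32 - 1, so 32 bits suffice
            const std::uint32_t t = std::uint32_t(x[i]) * y[j] + r[i + j] + carry;
            r[i + j] = static_cast<std::uint16_t>(t & 0xFFFF);
            carry = t >> 16;
        }
        r[i + y.size()] = static_cast<std::uint16_t>(carry); // this slot is still zero here
    }
    while(!r.empty() && r.back() == 0){ r.pop_back(); }
    return r;
}
```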

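Implementation

The main work consists of building the ranges and calling the kar_* functions written in the Preparation section. The function calls itself three times, which is the basis of the $\text{O}(n^{\log_2 3})$ bound. Unlike the formula in the Algorithm section, the code computes the middle term with differences, $x_0 y_1 + x_1 y_0 = x_0 y_0 + x_1 y_1 - (x_1 - x_0)(y_1 - y_0)$, where $x_0, y_0$ are the lower halves and $x_1, y_1$ the upper halves. The differences are never longer than the halves (whereas $x_0 + x_1$ can gain an extra limb), but they can be negative, which is why tx and ty carry a sign.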

void kar_rec_mul(const kar_range &x, const kar_range &y){
    // n: the larger operand size rounded up to a power of two
    std::size_t n = ceil_pow2_single((std::max)(x.size, y.size));
    if(n < 2){
        // base case: at most one limb each, multiply directly
        if(x.size > 0 && y.size > 0){
            kar_mul_range(x, y);
            normalize_data_size();
        }
        return;
    }
    n /= 2;
    // split at n limbs: x0, y0 are the lower halves, x1, y1 the (possibly empty) upper halves
    std::size_t xn = x.size < n ? x.size : n, yn = y.size < n ? y.size : n;
    kar_range x0, y0, x1, y1;
    x0.first = x.first,       x0.last = x.first + xn,  x0.size = xn;
    x1.first = x.first + xn,  x1.last = x.last,        x1.size = x.size - xn;
    y0.first = y.first,       y0.last = y.first + yn,  y0.size = yn;
    y1.first = y.first + yn,  y1.last = y.last,        y1.size = y.size - yn;
    integer z2, z0;
    {
        // signed differences tx = x1 - x0 and ty = y1 - y0 (-x0 and -y0 when the upper half is empty)
        integer tx, ty;
        if(x1.first != x1.last){
            tx.data.assign(x1.first, x1.last);
            tx.kar_sub(x0.first, x0.last);
        }else{
            tx.data.assign(x0.first, x0.last);
            tx.sign =  - 1;
        }
        if(y1.first != y1.last){
            ty.data.assign(y1.first, y1.last);
            ty.kar_sub(y0.first, y0.last);
        }else{
            ty.data.assign(y0.first, y0.last);
            ty.sign =  - 1;
        }
        // *this = -(x1 - x0)(y1 - y0): multiply the magnitudes, then set the negated sign
        kar_range rx, ry;
        rx.first = tx.data.begin(), rx.last = tx.data.end(), rx.size = tx.data.size();
        ry.first = ty.data.begin(), ry.last = ty.data.end(), ry.size = ty.data.size();
        kar_rec_mul(rx, ry);
        sign = tx.sign != ty.sign ? +1 :  - 1;
    }
    // middle term: x0 y0 + x1 y1 - (x1 - x0)(y1 - y0) = x0 y1 + x1 y0
    z0.kar_rec_mul(x0, y0);
    add(z0);
    if(x.size >= n && y.size >= n){
        // z2 = x1 y1 (skipped when an operand is shorter than n, since x1 or y1 is then empty)
        z2.kar_rec_mul(x1, y1);
        add(z2);
    }
    // result = middle * Radix^n + z0 + z2 * Radix^(2n)
    elemental_shift(n);
    add(z0);
    if(x.size >= n && y.size >= n){
        add(z2, n * 2);
    }
    normalize_data_size();
}
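For experimenting with the algorithm outside the `integer` class, here is a compact standalone sketch. Unlike `kar_rec_mul`, it uses the additive middle term $(A + B)(C + D) - A C - B D$ from the Algorithm section, so no signs are needed, and it recurses on copied vectors instead of ranges. All names, the 16-bit limbs, and the default cutoff of 32 limbs below which it falls back to schoolbook multiplication are assumptions, not values from the article.

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

using limbs = std::vector<std::uint16_t>; // least significant limb first, base 2^16

static void trim(limbs &a){
    while(!a.empty() && a.back() == 0){ a.pop_back(); }
}

// a += b * (2^16)^shift
static void add_shifted(limbs &a, const limbs &b, std::size_t shift){
    if(a.size() < b.size() + shift){ a.resize(b.size() + shift, 0); }
    std::uint32_t carry = 0;
    std::size_t i = shift;
    for(std::size_t j = 0; j < b.size(); ++j, ++i){
        const std::uint32_t s = std::uint32_t(a[i]) + b[j] + carry;
        a[i] = static_cast<std::uint16_t>(s & 0xFFFF);
        carry = s >> 16;
    }
    for(; carry > 0; ++i){
        if(i == a.size()){ a.push_back(0); }
        const std::uint32_t s = std::uint32_t(a[i]) + carry;
        a[i] = static_cast<std::uint16_t>(s & 0xFFFF);
        carry = s >> 16;
    }
}

// a -= b, assuming a >= b numerically
static void sub_in_place(limbs &a, const limbs &b){
    std::uint32_t borrow = 0;
    for(std::size_t i = 0; i < a.size(); ++i){
        const std::uint32_t sub = (i < b.size() ? b[i] : 0u) + borrow;
        borrow = a[i] < sub ? 1u : 0u;
        a[i] = static_cast<std::uint16_t>(a[i] + (borrow << 16) - sub);
    }
    trim(a);
}

static limbs schoolbook(const limbs &x, const limbs &y){
    limbs r(x.size() + y.size(), 0);
    for(std::size_t i = 0; i < x.size(); ++i){
        std::uint32_t carry = 0;
        for(std::size_t j = 0; j < y.size(); ++j){
            const std::uint32_t t = std::uint32_t(x[i]) * y[j] + r[i + j] + carry;
            r[i + j] = static_cast<std::uint16_t>(t & 0xFFFF);
            carry = t >> 16;
        }
        r[i + y.size()] = static_cast<std::uint16_t>(carry);
    }
    trim(r);
    return r;
}

// Falls back to schoolbook at or below `cutoff` limbs, and always for n < 4,
// where splitting would not make the operands strictly shorter.
limbs karatsuba(const limbs &x, const limbs &y, std::size_t cutoff = 32){
    const std::size_t n = std::max(x.size(), y.size());
    if(n < 4 || n <= cutoff){ return schoolbook(x, y); }
    const std::size_t h = n / 2; // split point in limbs
    auto low = [h](const limbs &a){
        limbs r(a.begin(), a.begin() + std::min(h, a.size()));
        trim(r);
        return r;
    };
    auto high = [h](const limbs &a){
        return a.size() > h ? limbs(a.begin() + h, a.end()) : limbs();
    };
    const limbs x0 = low(x), x1 = high(x), y0 = low(y), y1 = high(y);
    const limbs z0 = karatsuba(x0, y0, cutoff); // B D (low halves)
    const limbs z2 = karatsuba(x1, y1, cutoff); // A C (high halves)
    limbs sx = x0, sy = y0;
    add_shifted(sx, x1, 0);
    add_shifted(sy, y1, 0);
    limbs z1 = karatsuba(sx, sy, cutoff);       // (A + B)(C + D)
    sub_in_place(z1, z0);
    sub_in_place(z1, z2);                       // now A D + B C
    limbs r = z0;
    add_shifted(r, z1, h);
    add_shifted(r, z2, 2 * h);
    trim(r);
    return r;
}
```

The additive form trades the sign handling of `kar_rec_mul` for operands that can be one limb longer than the halves.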

Full source

#ifndef MULTI_PRECISION_INCLUDE_INTEGER_HPP
#define MULTI_PRECISION_INCLUDE_INTEGER_HPP

#include <algorithm>
#include <vector>
#include <iostream>
#include <string>
#include <iterator>
#include <utility>
#include <cstdint>
#include <climits>
#include <cmath>
#include <cstdlib>

namespace multi_precision{
    template<class UInt = std::uint16_t, class DoubleUInt = std::uint32_t, class DoubleInt = std::int32_t, DoubleUInt BitNum = sizeof(UInt) * CHAR_BIT>
    class integer{
    public:
        static const DoubleUInt base_mask = (static_cast<DoubleUInt>(1) << BitNum) - 1;
        static const UInt half = static_cast<UInt>(static_cast<DoubleUInt>(1) << (BitNum - 1));

        using data_type = std::vector<UInt>;
        data_type data;
        DoubleInt sign = +1;

        integer() = default;
        integer(const integer&) = default;
        integer(integer &&other) : data(std::move(other.data)), sign(other.sign){}
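        integer(int x){
            if(x == 0){ return; }
            sign = x >= 0 ? +1 :  - 1;
            // magnitude in unsigned arithmetic (safe for INT_MIN), split into as many
            // limbs as needed: |x| may not fit into a single UInt
            unsigned long long v = x >= 0 ? static_cast<unsigned long long>(x) : 0ull - static_cast<unsigned long long>(x);
            while(v > 0){
                data.push_back(static_cast<UInt>(v & base_mask));
                v >>= BitNum;
            }
        }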

        integer(UInt x){
            if(x == 0){ return; }
            data.resize(1);
            data[0] = x;
            sign = +1;
        }

        integer(const char *str){
            build_from_str(str);
        }

        integer(const std::string &str){
            build_from_str(str.c_str()); // NUL-terminated, so *iter == 0 marks the end
        }

        template<class StrIter>
        void build_from_str(StrIter iter){
            if(*iter == '-'){
                sign =  - 1;
                ++iter;
            }
            char buff[2] = { 0 };
            while(*iter){
                buff[0] = *iter;
                int a = std::atoi(buff);
                *this *= 10;
                unsigned_single_add(a);
                ++iter;
            }
            normalize_data_size(); // "0" and "-0" become the canonical zero
        }

        ~integer() = default;

        integer &operator =(const integer &other){
            data = other.data;
            sign = other.sign;
            return *this;
        }

        integer &operator =(integer &&other){
            data = std::move(other.data);
            sign = std::move(other.sign);
            return *this;
        }

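        bool operator <(const integer &other) const{
            if(sign < other.sign){ return true; }
            if(sign > other.sign){ return false; }
            // same sign: a longer magnitude is larger for positives, smaller for negatives
            if(data.size() < other.data.size()){ return sign > 0; }
            if(data.size() > other.data.size()){ return sign < 0; }
            // same length: the first differing limb from the top decides (reversed for negatives)
            for(std::size_t i = data.size(); i-- > 0; ){
                if(data[i] != other.data[i]){ return (data[i] < other.data[i]) == (sign > 0); }
            }
            return false;
        }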

        bool operator >(const integer &other) const{
            return other < *this;
        }

        bool operator <=(const integer &other) const{
            if(sign < other.sign){ return true; }
            if(sign > other.sign){ return false; }
            if(data.size() < other.data.size()){ return sign > 0; }
            if(data.size() > other.data.size()){ return sign < 0; }
            for(std::size_t i = data.size(); i-- > 0; ){
                if(data[i] != other.data[i]){ return (data[i] < other.data[i]) == (sign > 0); }
            }
            return true;
        }

        bool operator >=(const integer &other) const{
            return other <= *this;
        }

        bool operator ==(const integer &other) const{
            return sign == other.sign && data == other.data;
        }

        bool operator !=(const integer &other) const{
            return !(*this == other);
        }

        static bool unsigned_less(const data_type &lhs, const data_type &rhs){
            if(lhs.size() < rhs.size()){ return true; }
            if(lhs.size() > rhs.size()){ return false; }
            for(std::size_t i = lhs.size() - 1; i + 1 > 0; --i){
                if(lhs[i] < rhs[i]){ return true; }
                if(lhs[i] > rhs[i]){ return false; }
            }
            return false;
        }

        static bool unsigned_less_eq(const data_type &lhs, const data_type &rhs){
            if(lhs.size() < rhs.size()){ return true; }
            if(lhs.size() > rhs.size()){ return false; }
            for(std::size_t i = lhs.size() - 1; i + 1 > 0; --i){
                if(lhs[i] < rhs[i]){ return true; }
                if(lhs[i] > rhs[i]){ return false; }
            }
            return true;
        }

        void add(const integer &other, std::size_t n = 0){
            if(sign == other.sign){
                unsigned_add(other, n);
            }else{
                if(unsigned_less(data, other.data)){
                    integer temp(*this);
                    data = other.data;
                    unsigned_sub(temp);
                    sign = other.sign;
                }else{
                    unsigned_sub(other);
                }
            }
            normalize_data_size();
        }

        void sub(const integer &other, std::size_t n = 0){
            if(sign != other.sign){
                unsigned_add(other, n);
            }else{
                if(unsigned_less(data, other.data)){
                    integer temp(*this);
                    data = other.data;
                    unsigned_sub(temp);
                    sign = -other.sign; // equal signs and |*this| < |other|: the result flips sign
                }else{
                    unsigned_sub(other);
                }
            }
            normalize_data_size();
        }

        void unsigned_add(const integer &other, std::size_t n = 0){
            if(data.size() < other.data.size() + n){
                data.resize(other.data.size() + n);
            }
            DoubleUInt c = 0;
            std::size_t i;
            for(i = 0; i < other.data.size(); ++i){
                DoubleUInt v = static_cast<DoubleUInt>(data[i + n]) + static_cast<DoubleUInt>(other.data[i]) + c;
                data[i + n] = static_cast<UInt>(v & base_mask);
                c = v >> BitNum;
            }
            for(; c > 0; ++i){
                if(i + n >= data.size()){ data.resize(i + n + 1); } // the index used below is i + n
                DoubleUInt v = static_cast<DoubleUInt>(data[i + n]) + c;
                data[i + n] = static_cast<UInt>(v & base_mask);
                c = v >> BitNum;
            }
        }

        void unsigned_single_add(UInt c, std::size_t i = 0){
            if(i >= data.size()){ data.resize(i + 1); }
            for(; c > 0; ++i){
                DoubleUInt v = static_cast<DoubleUInt>(data[i]) + c;
                data[i] = static_cast<UInt>(v & base_mask);
                c = v >> BitNum;
            }
        }

        void unsigned_sub(const integer &other, std::size_t i = 0){
            unsigned_sub(other.data.begin(), other.data.end(), i);
        }

        void unsigned_sub(const typename data_type::const_iterator &other_first, const typename data_type::const_iterator &other_last, std::size_t i = 0){
            typename data_type::const_iterator iter = other_first;
            UInt c = 0;
            for(std::size_t n = std::distance(other_first, other_last); i < n; ++i, ++iter){
                const UInt &other_value(*iter);
                UInt t = data[i] - (other_value + c);
                if(data[i] < other_value + c){
                    c = 1;
                }else{
                    c = 0;
                }
                data[i] = t;
            }
            for(; c > 0; ++i){
                UInt t = data[i] - c;
                if(data[i] < c){
                    c = 1;
                }else{
                    c = 0;
                }
                data[i] = t;
            }
            normalize_data_size();
        }

        static void unsigned_mul(integer &r, const integer &lhs, const integer &rhs){
            std::size_t s = lhs.data.size() + rhs.data.size() - 1;
            r.data.resize(s + 1);
            for(std::size_t i = 0; i < lhs.data.size(); ++i){
                for(std::size_t j = 0; j < rhs.data.size(); ++j){
                    DoubleUInt c = static_cast<DoubleUInt>(lhs.data[i]) * static_cast<DoubleUInt>(rhs.data[j]);
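                    // stop as soon as the carry is zero; without the c > 0 test every partial
                    // product walks up to the top limb, making the whole loop O(n^3)
                    for(std::size_t k = 0; c > 0 && i + j + k < s + 1; ++k){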
                        std::size_t u = i + j + k;
                        DoubleUInt v = static_cast<DoubleUInt>(r.data[u]) + c;
                        UInt a = static_cast<UInt>(v & base_mask);
                        r.data[u] = a;
                        c = v >> BitNum;
                    }
                }
            }
            r.normalize_data_size();
        }
        static integer mul(const integer &lhs, const integer &rhs){
            integer r;
            r.kar_mul(lhs, rhs);
            return r;
        }

        static integer classic_mul(const integer &lhs, const integer &rhs){
            integer r;
            unsigned_mul(r, lhs, rhs);
            return r;
        }

        static void unsigned_single_mul(integer &r, const integer &lhs, UInt rhs){
            std::size_t s = lhs.data.size() + 1;
            r.data.resize(s + 1);
            for(std::size_t i = 0; i < lhs.data.size(); ++i){
                DoubleUInt c = static_cast<DoubleUInt>(lhs.data[i]) * static_cast<DoubleUInt>(rhs);
                for(std::size_t k = 0; i + k < s + 1; ++k){
                    std::size_t u = i + k;
                    DoubleUInt v = static_cast<DoubleUInt>(r.data[u]) + c;
                    UInt a = static_cast<UInt>(v & base_mask);
                    r.data[u] = a;
                    c = v >> BitNum;
                }
            }
            r.normalize_data_size();
        }

        struct kar_range{
            typename data_type::const_iterator first, last;
            std::size_t size;
        };

        void kar_mul(const integer &lhs, const integer &rhs){
            kar_range tx, ty;
            tx.first = lhs.data.begin(), tx.last = lhs.data.end(), tx.size = lhs.data.size();
            ty.first = rhs.data.begin(), ty.last = rhs.data.end(), ty.size = rhs.data.size();
            kar_rec_mul(tx, ty);
            sign = lhs.sign * rhs.sign;
            normalize_data_size(); // a zero product gets the canonical sign +1
        }


        void kar_mul_range(const kar_range &lhs, const kar_range &rhs){
            data.resize(lhs.size + rhs.size - 1);
            if(data.empty()){
                normalize_data_size();
                return;
            }
            std::size_t n = 0;
            for(typename data_type::const_iterator rhs_iter = rhs.first; rhs_iter != rhs.last; ++rhs_iter, ++n){
                UInt rhs_value = *rhs_iter, lhs_value;
                std::size_t m = 0;
                for(typename data_type::const_iterator lhs_iter = lhs.first; lhs_iter != lhs.last; ++lhs_iter, ++m){
                    lhs_value = *lhs_iter;
                    DoubleUInt a = static_cast<DoubleUInt>(rhs_value) * static_cast<DoubleUInt>(lhs_value);
                    unsigned_single_add(static_cast<UInt>(a & base_mask), n + m);
                    UInt temp = static_cast<UInt>((a >> BitNum) & base_mask);
                    if(temp > 0){
                        unsigned_single_add(temp, n + m + 1);
                    }
                }
            }
        }

        bool kar_less(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last) const{
            std::size_t lhs_n = data.size(), rhs_n = std::distance(rhs_first, rhs_last);
            typename data_type::const_iterator lhs_first = data.begin(), lhs_iter = lhs_first, rhs_iter = rhs_first;
            if(lhs_n < rhs_n){
                std::size_t n = rhs_n - lhs_n;
                for(std::size_t i = 0; i < n; ++i){
                    if(*(rhs_iter + (rhs_n - i - 1)) > 0){
                        return true;
                    }
                }
                rhs_iter += lhs_n - 1;
                lhs_iter = lhs_first + (lhs_n - 1);
            }else if(lhs_n > rhs_n){
                std::size_t n = lhs_n - rhs_n;
                for(std::size_t i = 0; i < n; ++i){
                    if(*(lhs_iter + (lhs_n - i - 1)) > 0){
                        return false;
                    }
                }
                // the extra high limbs of *this are all zero: compare the overlapping limbs
                lhs_iter += rhs_n - 1;
                rhs_iter += rhs_n - 1;
            }else{
                lhs_iter += lhs_n - 1;
                rhs_iter += rhs_n - 1;
            }
            for(; ; ){
                UInt l = *lhs_iter, r = *rhs_iter;
                if(l < r){ return true; }
                if(l > r){ return false; }
                if(lhs_iter == lhs_first){
                    break;
                }
                --lhs_iter, --rhs_iter;
            }
            return false;
        }

        void kar_add(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last){
            std::size_t rhs_size = std::distance(rhs_first, rhs_last);
            if(data.size() < rhs_size){
                data.resize(rhs_size);
            }
            DoubleUInt c = 0;
            std::size_t i = 0;
            for(typename data_type::const_iterator iter = rhs_first; iter != rhs_last; ++iter, ++i){
                DoubleUInt s = static_cast<DoubleUInt>(data[i]) + static_cast<DoubleUInt>(*iter) + c;
                data[i] = static_cast<UInt>(s & base_mask);
                c = s >> BitNum;
            }
            // propagate the remaining carry into the higher limbs of *this,
            // appending a limb only when the carry leaves the top
            for(; c > 0; ++i){
                if(i >= data.size()){ data.push_back(0); }
                DoubleUInt s = static_cast<DoubleUInt>(data[i]) + c;
                data[i] = static_cast<UInt>(s & base_mask);
                c = s >> BitNum;
            }
        }

        void kar_sub(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last){
            if(kar_less(rhs_first, rhs_last)){
                integer s(*this);
                data.assign(rhs_first, rhs_last);
                unsigned_sub(s);
                sign *=  - 1;
                normalize_data_size();
            }else{
                std::size_t n = std::distance(rhs_first, rhs_last);
                if(n > data.size()){
                    unsigned_sub(rhs_first, rhs_first + data.size());
                }else{
                    unsigned_sub(rhs_first, rhs_last);
                }
                normalize_data_size();
            }
        }

        void kar_rec_mul(const kar_range &x, const kar_range &y){
            std::size_t n = ceil_pow2_single((std::max)(x.size, y.size));
            if(n < 2){
                if(x.size > 0 && y.size > 0){
                    kar_mul_range(x, y);
                    normalize_data_size();
                }
                return;
            }
            n /= 2;
            std::size_t xn = x.size < n ? x.size : n, yn = y.size < n ? y.size : n;
            kar_range x0, y0, x1, y1;
            x0.first = x.first,       x0.last = x.first + xn,  x0.size = xn;
            x1.first = x.first + xn,  x1.last = x.last,        x1.size = x.size - xn;
            y0.first = y.first,       y0.last = y.first + yn,  y0.size = yn;
            y1.first = y.first + yn,  y1.last = y.last,        y1.size = y.size - yn;
            integer z2, z0;
            {
                integer tx, ty;
                if(x1.first != x1.last){
                    tx.data.assign(x1.first, x1.last);
                    tx.kar_sub(x0.first, x0.last);
                }else{
                    tx.data.assign(x0.first, x0.last);
                    tx.sign =  - 1;
                }
                if(y1.first != y1.last){
                    ty.data.assign(y1.first, y1.last);
                    ty.kar_sub(y0.first, y0.last);
                }else{
                    ty.data.assign(y0.first, y0.last);
                    ty.sign =  - 1;
                }
                kar_range rx, ry;
                rx.first = tx.data.begin(), rx.last = tx.data.end(), rx.size = tx.data.size();
                ry.first = ty.data.begin(), ry.last = ty.data.end(), ry.size = ty.data.size();
                kar_rec_mul(rx, ry);
                sign = tx.sign != ty.sign ? +1 :  - 1;
            }
            z0.kar_rec_mul(x0, y0);
            add(z0);
            if(x.size >= n && y.size >= n){
                z2.kar_rec_mul(x1, y1);
                add(z2);
            }
            elemental_shift(n);
            add(z0);
            if(x.size >= n && y.size >= n){
                add(z2, n * 2);
            }
            normalize_data_size();
        }

        struct quo_rem{
            integer quo, rem;

            quo_rem() = default;
            quo_rem(const quo_rem&) = default;
            quo_rem(quo_rem &&other) : quo(std::move(other.quo)), rem(std::move(other.rem)){}
            ~quo_rem() = default;
        };

        static quo_rem unsigned_div(const integer &lhs, const integer &rhs){
            quo_rem qr;
            qr.quo.data.reserve(lhs.data.size());
            qr.rem.data.reserve(rhs.data.size());
            qr.rem.data.push_back(0);
            for(std::size_t i = lhs.data.size() * static_cast<std::size_t>(BitNum) - 1; i + 1 > 0; --i){
                if(!qr.rem.data.empty()){
                    qr.rem <<= 1;
                }else{
                    qr.rem.data.push_back(0);
                }
                qr.rem.data[0] |= (lhs.data[i / static_cast<std::size_t>(BitNum)] >> (i % static_cast<std::size_t>(BitNum))) & 1;
                if(unsigned_less_eq(rhs.data, qr.rem.data)){
                    qr.rem.unsigned_sub(rhs);
                    std::size_t t = i / static_cast<std::size_t>(BitNum);
                    if(qr.quo.data.size() < t + 1){ qr.quo.data.resize(t + 1); }
                    qr.quo.data[t] |= 1 << (i % BitNum);
                }
            }
            return qr;
        }

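        static quo_rem div(const integer &lhs, const integer &rhs){
            quo_rem qr = unsigned_div(lhs, rhs);
            // truncating division: the quotient's sign is the product of the signs,
            // the remainder takes the sign of the dividend (lhs = quo * rhs + rem)
            qr.quo.sign = lhs.sign == rhs.sign ? +1 :  - 1;
            qr.rem.sign = lhs.sign;
            qr.quo.normalize_data_size();
            qr.rem.normalize_data_size();
            return qr;
        }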

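        // Round a limb count up to the next power of two (1 for x == 0).
        // Takes std::size_t because operand sizes can exceed the range of UInt.
        static std::size_t ceil_pow2_single(std::size_t x){
            std::size_t n = 0, m = 0; // n: number of set bits, m: bit length
            while(x > 0){
                if((x & 1) > 0){
                    ++n;
                }
                ++m;
                x >>= 1;
            }
            if(n == 1){
                return static_cast<std::size_t>(1) << (m - 1); // already a power of two
            }
            return static_cast<std::size_t>(1) << m;
        }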

        void elemental_shift(std::size_t n){
            data.insert(data.begin(), n, 0);
        }

        std::size_t bit_num() const{
            return bit_num(data.back());
        }

        static std::size_t bit_num(UInt x){
            std::size_t n = 0;
            while(x > 0){
                ++n;
                x >>= 1;
            }
            return n;
        }

        std::size_t finite_bit_shift_lsr(std::size_t n){
            std::size_t
                digit = n / static_cast<std::size_t>(BitNum),
                shift = n % static_cast<std::size_t>(BitNum),
                rev_shift = static_cast<std::size_t>(BitNum) - shift;
            UInt c = 0;
            for(std::size_t i = 0; i < data.size(); ++i){
                UInt x = data[i];
                data[i] = (x << shift) | c;
                if(rev_shift < BitNum){
                    c = x >> rev_shift;
                }else{
                    c = 0;
                }
            }
            if(c > 0){ data.push_back(c); }
            return digit;
        }

        void bit_shift_lsr(std::size_t n){
            elemental_shift(finite_bit_shift_lsr(n));
        }

        void bit_shift_rsl(std::size_t n){
            std::size_t
                digit = n / static_cast<std::size_t>(BitNum),
                shift = n % static_cast<std::size_t>(BitNum),
                rev_shift = BitNum - shift,
                size = data.size();

            if(digit > 0){
                for(std::size_t i = 0, length = size - digit; i < length; ++i){
                    data[i] = data[i + digit];
                }
                for(std::size_t i = 0; i < digit; ++i){
                    data[size - i - 1] = 0;
                }
            }

            UInt c = 0;
            for(std::size_t i = 0; i < data.size(); ++i){
                std::size_t j = size - i - 1;
                UInt x = data[j];
                data[j] = (x >> shift) | c;
                if(rev_shift < BitNum){
                    c = static_cast<UInt>(x << rev_shift); // low bits move to the top of the next lower limb
                }else{
                    c = 0;
                }
            }
            normalize_data_size();
        }

        void normalize_data_size(){
            if(data.empty()){
                sign = +1;
                return;
            }
            std::size_t i = data.size() - 1;
            for(; ; ){
                if(i == 0 && data.front() == 0){
                    data.clear();
                    sign = +1;
                    return;
                }
                if(data[i] == 0){
                    --i;
                }else{
                    break;
                }
            }
            data.resize(i + 1);
        }

#define MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(assign_op, op, type) \
    integer &operator assign_op(const type &rhs){ *this = *this op integer(rhs); return *this; }

        MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(+= , +, int);
        MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(-= , -, int);
        MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(*= , *, int);
        MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(/= , /, int);
        MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(%= , %, int);

        MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(+= , +, unsigned int);
        MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(-= , -, unsigned int);
        MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(*= , *, unsigned int);
        MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(/= , /, unsigned int);
        MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(%= , %, unsigned int);
    };

#define MULTI_PRECISION_COMPARE_OPERATOR(op, type) \
    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum> \
    bool operator op(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const type &rhs){ \
        return lhs op integer<UInt, DoubleUInt, DoubleInt, BitNum>(rhs); \
    } \
    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum> \
    bool operator op(const type &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){ \
        return integer<UInt, DoubleUInt, DoubleInt, BitNum>(lhs) op rhs; \
    }

    MULTI_PRECISION_COMPARE_OPERATOR(==, int);
    //MULTI_PRECISION_COMPARE_OPERATOR(!=, int);
    MULTI_PRECISION_COMPARE_OPERATOR(<=, int);
    MULTI_PRECISION_COMPARE_OPERATOR(>=, int);

    MULTI_PRECISION_COMPARE_OPERATOR(==, unsigned int);
    //MULTI_PRECISION_COMPARE_OPERATOR(!=, unsigned int);
    MULTI_PRECISION_COMPARE_OPERATOR(<=, unsigned int);
    MULTI_PRECISION_COMPARE_OPERATOR(>=, unsigned int);

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> operator +(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
        integer<UInt, DoubleUInt, DoubleInt, BitNum> t(lhs);
        t.add(rhs);
        return t;
    }

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> operator -(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
        integer<UInt, DoubleUInt, DoubleInt, BitNum> t(lhs);
        t.sub(rhs);
        return t;
    }

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> operator *(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
        return integer<UInt, DoubleUInt, DoubleInt, BitNum>::mul(lhs, rhs);
    }

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> operator /(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
        return integer<UInt, DoubleUInt, DoubleInt, BitNum>::div(lhs, rhs).quo;
    }

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> operator %(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
        return integer<UInt, DoubleUInt, DoubleInt, BitNum>::div(lhs, rhs).rem;
    }

    // The shift count is std::size_t rather than UInt: with UInt it would also take part in
    // template argument deduction, so a plain int count such as `x << 3` would fail to deduce.
    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> operator <<(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, std::size_t n){
        integer<UInt, DoubleUInt, DoubleInt, BitNum> x(lhs);
        x.bit_shift_lsr(n);
        return x;
    }

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> operator >>(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, std::size_t n){
        integer<UInt, DoubleUInt, DoubleInt, BitNum> x(lhs);
        x.bit_shift_rsl(n);
        return x;
    }

    // Compound assignment is a free function, so it has no *this: it must take lhs by
    // non-const reference, modify it in place, and return it by reference.
    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> &operator +=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
        lhs.add(rhs);
        return lhs;
    }

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> &operator -=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
        lhs.sub(rhs);
        return lhs;
    }

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> &operator *=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
        lhs = integer<UInt, DoubleUInt, DoubleInt, BitNum>::mul(lhs, rhs);
        return lhs;
    }

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> &operator /=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
        lhs = integer<UInt, DoubleUInt, DoubleInt, BitNum>::div(lhs, rhs).quo;
        return lhs;
    }

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> &operator %=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
        lhs = integer<UInt, DoubleUInt, DoubleInt, BitNum>::div(lhs, rhs).rem;
        return lhs;
    }

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> &operator <<=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, std::size_t n){
        integer<UInt, DoubleUInt, DoubleInt, BitNum> x(lhs);
        x.bit_shift_lsr(n);
        lhs = std::move(x);
        return lhs;
    }

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    integer<UInt, DoubleUInt, DoubleInt, BitNum> &operator >>=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, std::size_t n){
        integer<UInt, DoubleUInt, DoubleInt, BitNum> x(lhs);
        x.bit_shift_rsl(n);
        lhs = std::move(x);
        return lhs;
    }

#define MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(op, type) \
    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum> \
    integer<UInt, DoubleUInt, DoubleInt, BitNum> operator op(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const type &rhs){ return lhs op integer<UInt, DoubleUInt, DoubleInt, BitNum>(rhs); } \
    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum> \
    integer<UInt, DoubleUInt, DoubleInt, BitNum> operator op(const type &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){ return integer<UInt, DoubleUInt, DoubleInt, BitNum>(lhs) op rhs; }

    MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(+, int);
    MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(-, int);
    MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(*, int);
    MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(/ , int);
    MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(%, int);

    MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(+, unsigned int);
    MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(-, unsigned int);
    MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(*, unsigned int);
    MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(/ , unsigned int);
    MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(%, unsigned int);

    template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
    std::ostream &operator <<(
        std::ostream &os,
        integer<UInt, DoubleUInt, DoubleInt, BitNum> rhs
    ){
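        using integer = integer<UInt, DoubleUInt, DoubleInt, BitNum>;
        // Read the sign before the loop: once rhs has been divided down to zero
        // its sign no longer carries the original value's sign (zero is normalized to +1).
        const bool negative = rhs.sign < 0 && !rhs.data.empty();
        std::vector<std::string> rseq;
        integer lo(10);
        for(; rhs.data.size() > 0; rhs /= lo){
            integer temp = rhs % lo;
            if(temp != 0){
                rseq.push_back(std::to_string(temp.data[0]));
            }else{
                rseq.push_back(std::to_string(0));
            }
        }
        os << (negative ? "-" : "");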
        if(!rseq.empty()){
            for(auto iter = rseq.rbegin(); iter != rseq.rend(); ++iter){
                os << *iter;
            }
        }else{
            os << "0";
        }
        return os;
    }
}

#endif // MULTI_PRECISION_INCLUDE_INTEGER_HPP
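
The decimal output in `operator <<` above works by repeatedly dividing by 10 and collecting remainders, and it relies on `normalize_data_size` to keep the limb vector free of high-order zeros. The following is a minimal, self-contained sketch of those two steps on a toy type, assuming 16-bit little-endian limbs (as in the 16-bit benchmark) and a separate `sign` of +1 or -1. The names `toy_integer`, `normalize`, `div_small` and `to_decimal` are hypothetical and are not part of `integer.hpp`. Unlike the listing, which calls the general `/` and `%`, the sketch uses single-limb short division. It also reads the sign before the loop, because normalization resets the sign of zero to +1.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical toy type, NOT part of integer.hpp: magnitude stored as
// little-endian 16-bit limbs (data[0] is least significant), sign is +1 or -1.
struct toy_integer {
    std::vector<std::uint16_t> data;
    int sign = +1;
};

// Drop high-order zero limbs; an empty vector means zero, whose sign is
// always +1 (the same invariant normalize_data_size() enforces).
void normalize(toy_integer &x) {
    while (!x.data.empty() && x.data.back() == 0) {
        x.data.pop_back();
    }
    if (x.data.empty()) {
        x.sign = +1;
    }
}

// Short division of the magnitude by a single-limb divisor, in place,
// from the most significant limb down. Returns the remainder.
std::uint32_t div_small(toy_integer &x, std::uint32_t divisor) {
    // rem < divisor <= 2^16, so (rem << 16) | limb fits in 32 bits and
    // each quotient limb fits in 16 bits.
    assert(divisor != 0 && divisor <= 0x10000u);
    std::uint32_t rem = 0;
    for (std::size_t i = x.data.size(); i-- > 0;) {
        std::uint32_t cur = (rem << 16) | x.data[i];
        x.data[i] = static_cast<std::uint16_t>(cur / divisor);
        rem = cur % divisor;
    }
    normalize(x);
    return rem;
}

// Decimal conversion by repeated division by 10: digits come out least
// significant first and are reversed at the end. The sign is captured
// before the loop because the value is zero (sign +1) when the loop ends.
std::string to_decimal(toy_integer x) {
    normalize(x);
    const bool negative = x.sign < 0;
    std::string digits;
    while (!x.data.empty()) {
        digits.push_back(static_cast<char>('0' + div_small(x, 10)));
    }
    if (digits.empty()) {
        digits = "0";
    }
    if (negative) {
        digits.push_back('-');
    }
    std::reverse(digits.begin(), digits.end());
    return digits;
}
```

For example, limbs `{0x5678, 0x1234, 0x0000}` with sign -1 (that is, -0x12345678 with one redundant zero limb) give `"-305419896"`. Short division is enough here because the divisor fits in one limb. The listing instead goes through the general `%` and `/=`, which is simpler to write but performs two full divisions per decimal digit.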
