Overview
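The previous article, on the four basic arithmetic operations, covered the implementation of classical multiplication. Given two $n$-digit polynomials $f_{n}, g_{n} \in \text{R}[x]$, classical multiplication needs $\text{O}(n^2)$ multiplications, i.e. quadratic time, to obtain the product $h_{2n} = f_{n} g_{n}$. Here $x$ is the polynomial variable and $\text{R}$ is a commutative ring with identity. For multiple-precision integers, $x$ plays the role of the radix.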
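The expansion below splits each $n$-digit operand into an upper and a lower half of $n/2$ digits, $f_n = A x^{n/2} + B$ and $g_n = C x^{n/2} + D$. Four products appear, and splitting each of those four products recursively in the same way again requires four multiplications each.
$$f_n g_n = (A x^{n/2} + B)(C x^{n/2} + D) = AC\,x^{n} + (AD + BC)\,x^{n/2} + BD = h_{2n}$$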
Karatsuba multiplication reduces this $\text{O}(n^2)$ to the subquadratic $\text{O}(n^{\log_2 3})$.
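The two growth rates can be read off from the number of single-digit multiplications. The following is a standard derivation, assuming $n$ is a power of two; the cost functions $T_{\text{classic}}$ and $T_{\text{kar}}$ are names introduced here for illustration. It counts four half-size products per level for the classical split and three for Karatsuba.

```latex
\begin{aligned}
T_{\text{classic}}(n) &= 4\,T_{\text{classic}}(n/2),\quad T_{\text{classic}}(1) = 1
  &&\Rightarrow\quad T_{\text{classic}}(n) = 4^{\log_2 n} = n^{2},\\
T_{\text{kar}}(n) &= 3\,T_{\text{kar}}(n/2),\quad T_{\text{kar}}(1) = 1
  &&\Rightarrow\quad T_{\text{kar}}(n) = 3^{\log_2 n} = n^{\log_2 3} \approx n^{1.585}.
\end{aligned}
```

For $n = 2048 = 2^{11}$ this gives $4^{11} = 4194304$ digit products against $3^{11} = 177147$. The additions and subtractions, which cost $\text{O}(n)$ per level, are ignored in this count.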
Algorithm
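The product $h_{2n} = f_n g_n$ can be rewritten as follows, where $A$ and $B$ are the upper and lower halves of $f_n$, and $C$ and $D$ those of $g_n$.
$$f_n g_n = AC\,x^{n} + \bigl((A + B)(C + D) - AC - BD\bigr)\,x^{n/2} + BD = h_{2n}$$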
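The number of additions and subtractions goes up, but only three multiplications remain: $A C$, $B D$, and $(A + B)(C + D)$. If each of these three products is in turn computed recursively with three multiplications, the total number of multiplications becomes $\text{O}(n^{\log_2 3})$. This is Karatsuba multiplication.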
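The recursion can be sketched independently of the article's integer class. The block below is a minimal illustration, not the article's implementation: it works on unsigned little-endian limbs in base $2^{16}$, copies the halves instead of using ranges, and uses the additive form $(A + B)(C + D) - AC - BD$. The names `limbs`, `trim`, `add_shifted`, `sub_in_place`, `schoolbook` and `karatsuba`, and the cutoff of 8 limbs, are assumptions made for this sketch.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Little-endian limbs in base 2^16 (least significant limb first).
using limbs = std::vector<std::uint16_t>;

// Drops high-order zero limbs so that equal values compare equal.
static void trim(limbs &v){
    while(!v.empty() && v.back() == 0){ v.pop_back(); }
}

// r += a * (2^16)^shift
static void add_shifted(limbs &r, const limbs &a, std::size_t shift){
    if(r.size() < a.size() + shift){ r.resize(a.size() + shift); }
    std::uint32_t carry = 0;
    for(std::size_t i = 0; i < a.size() || carry > 0; ++i){
        if(i + shift >= r.size()){ r.push_back(0); }
        std::uint32_t s = std::uint32_t(r[i + shift]) + (i < a.size() ? a[i] : 0u) + carry;
        r[i + shift] = static_cast<std::uint16_t>(s & 0xFFFFu);
        carry = s >> 16;
    }
}

// r -= a, assuming r >= a as values.
static void sub_in_place(limbs &r, const limbs &a){
    std::uint32_t borrow = 0;
    for(std::size_t i = 0; i < r.size(); ++i){
        std::uint32_t sub = (i < a.size() ? a[i] : 0u) + borrow;
        borrow = r[i] < sub ? 1u : 0u;
        r[i] = static_cast<std::uint16_t>(r[i] + 0x10000u - sub);
    }
    assert(borrow == 0);
    trim(r);
}

// Classical O(n^2) product, used as the base case.
static limbs schoolbook(const limbs &a, const limbs &b){
    limbs r(a.size() + b.size(), 0);
    for(std::size_t i = 0; i < a.size(); ++i){
        std::uint32_t carry = 0;
        for(std::size_t j = 0; j < b.size(); ++j){
            std::uint32_t t = std::uint32_t(a[i]) * b[j] + r[i + j] + carry;
            r[i + j] = static_cast<std::uint16_t>(t & 0xFFFFu);
            carry = t >> 16;
        }
        r[i + b.size()] = static_cast<std::uint16_t>(carry);
    }
    trim(r);
    return r;
}

// Karatsuba product with three recursive multiplications per level.
static limbs karatsuba(const limbs &x, const limbs &y){
    std::size_t n = std::max(x.size(), y.size());
    if(n <= 8){ return schoolbook(x, y); }            // arbitrary cutoff for this sketch
    std::size_t h = n / 2;
    std::size_t xs = std::min(h, x.size()), ys = std::min(h, y.size());
    limbs b(x.begin(), x.begin() + xs), a(x.begin() + xs, x.end());   // x = a*R^h + b
    limbs d(y.begin(), y.begin() + ys), c(y.begin() + ys, y.end());   // y = c*R^h + d
    limbs ac = karatsuba(a, c), bd = karatsuba(b, d);
    add_shifted(a, b, 0);                              // a now holds A + B
    add_shifted(c, d, 0);                              // c now holds C + D
    limbs mid = karatsuba(a, c);                       // (A + B)(C + D)
    sub_in_place(mid, ac);
    sub_in_place(mid, bd);                             // mid = AD + BC
    limbs r = bd;
    add_shifted(r, mid, h);
    add_shifted(r, ac, 2 * h);
    trim(r);
    return r;
}
```

Because the additive form never produces a negative intermediate value, this sketch needs no sign handling, at the price of operands that can grow by one limb when the halves are added.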
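Measurements
( Implementation details follow in the next sections... )
The speed of classical multiplication ( upper value ) and Karatsuba multiplication ( lower value ) was compared on pairs of values with 128 * (step num) digits of 16 bits each.
The measurements were taken with Microsoft Visual C++ 2013 on a Core i7-3770 3.40GHz CPU.
The code used for the comparison is as follows.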
#include "integer.hpp"
#include <iostream>
#include <boost/timer.hpp>
using integer = multi_precision::integer<>;
void kar_time(std::size_t num){
integer a, b, c, d;
for(std::size_t i = 0; i < num; ++i){
a.data.push_back(i + 1);
b.data.push_back(num - i);
}
{
boost::timer t;
c = integer::classic_mul(a, b);
double u = t.elapsed();
std::cout << u << std::endl;
}
{
boost::timer t;
d = integer::mul(a, b);
double u = t.elapsed();
std::cout << u << std::endl;
}
if(c == d){
std::cout << "success..." << std::endl;
}else{
std::cout << "fail..." << std::endl;
}
}
int main(){
for(std::size_t i = 0; i < 16; ++i){
std::cout << "step " << (i + 1) << "--------" << std::endl;
kar_time((i + 1) * 128);
std::cout << std::endl;
}
return 0;
}
step 1--------
0.002
0
success...
step 2--------
0.013
0.001
success...
step 3--------
0.038
0.002
success...
step 4--------
0.089
0.002
success...
step 5--------
0.172
0.005
success...
step 6--------
0.295
0.005
success...
step 7--------
0.474
0.005
success...
step 8--------
0.704
0.004
success...
step 9--------
1.002
0.01
success...
step 10--------
1.369
0.01
success...
step 11--------
1.824
0.012
success...
step 12--------
2.363
0.01
success...
step 13--------
2.995
0.013
success...
step 14--------
3.753
0.012
success...
step 15--------
4.608
0.013
success...
step 16--------
5.589
0.008
success...
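In the final multiplication of 128 * 16 = 2048 digits, classical multiplication takes more than 5.5 seconds, whereas Karatsuba multiplication takes only 0.008 seconds. Internally, the algorithm wastes no work when the digit count is a power of two ( $2^n$ digits ), which makes such sizes particularly fast; even if the 0.013 seconds of the preceding step 15 is taken as the Karatsuba figure, the difference is still close to 430-fold.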
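Several of the Karatsuba timings above are 0 or only a few thousandths of a second, which is close to the 0.001 second granularity visible in the output. One way to get steadier figures is to repeat the call and average. The helper below is a sketch using `std::chrono::steady_clock` from the standard library; the name `average_seconds` and its parameters are assumptions made for illustration and are not part of the article's code.

```cpp
#include <cassert>
#include <chrono>

// Runs f() `repeat` times and returns the average wall-clock seconds per call.
template<class F>
double average_seconds(F f, int repeat){
    assert(repeat > 0);
    auto start = std::chrono::steady_clock::now();
    for(int i = 0; i < repeat; ++i){ f(); }
    std::chrono::duration<double> elapsed = std::chrono::steady_clock::now() - start;
    return elapsed.count() / repeat;
}
```

With the class from this article, one could for instance pass a lambda that wraps `integer::mul(a, b)` and choose a repeat count large enough that the total run takes well over a millisecond.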
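Preparation
Representing operands as ranges
If the algorithm is written down naively, the number of variables grows, and time spent copying data slows it down. To avoid this, it seems effective to perform each operation on ranges ( iterator pairs ) over the parts of the data that are not modified.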
struct kar_range{
typename data_type::const_iterator first, last;
std::size_t size;
};
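As an illustration of how such a range avoids copies, the sketch below splits one view into a low and a high part by adjusting iterators only. `range_view` mirrors the shape of `kar_range`; `split_at` is a hypothetical helper that does not appear in the article, and the 16-bit `data_type` is assumed to match the default `UInt`.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using data_type = std::vector<std::uint16_t>;

// Same shape as kar_range: a read-only view plus its length.
struct range_view{
    data_type::const_iterator first, last;
    std::size_t size;
};

// Splits r into (low, high) at position n without copying any limb.
// If r is shorter than n, the high part is empty.
std::pair<range_view, range_view> split_at(const range_view &r, std::size_t n){
    std::size_t k = std::min(r.size, n);
    range_view lo{ r.first, r.first + k, k };
    range_view hi{ r.first + k, r.last, r.size - k };
    return std::make_pair(lo, hi);
}
```

The views stay valid only while the underlying vector is neither resized nor destroyed, which is why this approach suits operands that are read but not modified.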
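Rounding up to a power of two
Karatsuba multiplication basically splits its operands into ranges at powers of two, so a value of the unit type has to be rounded up to a power of two.
The function counts the binary digits of x in m until x becomes 0, and increments n each time a 1 bit appears. When n is 1, every bit other than the most significant one is 0, so x is already a power of two and it is enough to restore the input and return it. When n is 2 or more, at least one bit below the most significant one is 1, so the value is rounded up to the next power of two.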
static UInt ceil_pow2_single(UInt x){
std::size_t n = 0, m = 0;
while(x > 0){
if((x & 1) > 0){
++n;
}
++m;
x >>= 1;
}
if(n == 1){
return static_cast<UInt>(1 << (m - 1));
}
return static_cast<UInt>(1 << m);
}
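A standalone variant of the same rounding is sketched below for comparison. It is an assumption-labeled alternative, not the article's function: `ceil_pow2` is a name chosen here, and it works on `std::size_t` because the listing above returns `UInt`, so with a 16-bit `UInt` a result of $2^{16}$ would not fit.

```cpp
#include <cassert>
#include <cstddef>

// Smallest power of two that is >= x (returns 1 for x == 0 and x == 1).
std::size_t ceil_pow2(std::size_t x){
    std::size_t p = 1;
    while(p < x){
        assert(p <= static_cast<std::size_t>(-1) / 2);   // guard against overflow of p
        p <<= 1;
    }
    return p;
}
```

For the inputs 0, 1, 5 and 8 this returns 1, 1, 8 and 8, the same values the bit-counting version produces.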
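Radix shift
Multiplies a multiple-precision integer by $\text{Radix}^n$. This corresponds to the familiar left shift in binary. For a multiple-precision integer it is enough to insert n zeros at the front of the data, i.e. at the least significant end.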
void elemental_shift(std::size_t n){
data.insert(data.begin(), n, 0);
}
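The effect can be checked on a small value. The sketch below assumes little-endian 16-bit limbs as in the article's default configuration; `to_u64` and `shift_limbs` are hypothetical helpers introduced only for this check.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Interprets little-endian base-2^16 limbs as a 64-bit value (at most 4 limbs).
std::uint64_t to_u64(const std::vector<std::uint16_t> &limbs){
    assert(limbs.size() <= 4);
    std::uint64_t v = 0;
    for(std::size_t i = limbs.size(); i-- > 0; ){
        v = (v << 16) | limbs[i];
    }
    return v;
}

// Multiplies by (2^16)^n by inserting n zero limbs at the least significant end.
void shift_limbs(std::vector<std::uint16_t> &limbs, std::size_t n){
    limbs.insert(limbs.begin(), n, 0);
}
```

Shifting the two-limb value 0x00011234 by two limbs yields a four-limb value equal to the original shifted left by 32 bits.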
Comparison function kar_less ( range based )
Evaluates *this < [rhs_first, rhs_last) for the range given by rhs_first and rhs_last. It is used by kar_sub ( range based ).
bool kar_less(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last) const{
std::size_t lhs_n = data.size(), rhs_n = std::distance(rhs_first, rhs_last);
typename data_type::const_iterator lhs_first = data.begin(), lhs_iter = lhs_first, rhs_iter = rhs_first;
if(lhs_n < rhs_n){
std::size_t n = rhs_n - lhs_n;
for(std::size_t i = 0; i < n; ++i){
if(*(rhs_iter + (rhs_n - i - 1)) > 0){
return true;
}
}
rhs_iter += lhs_n - 1;
lhs_iter = lhs_first + (lhs_n - 1);
}else if(lhs_n > rhs_n){
std::size_t n = lhs_n - rhs_n;
for(std::size_t i = 0; i < n; ++i){
if(*(lhs_iter + (lhs_n - i - 1)) > 0){
return false;
}
}
// Position both iterators on the highest limb the two operands share.
lhs_iter += rhs_n - 1;
rhs_iter = rhs_iter + (rhs_n - 1);
}else{
lhs_iter += lhs_n - 1;
rhs_iter += rhs_n - 1;
}
for(; ; ){
UInt l = *lhs_iter, r = *rhs_iter;
if(l < r){ return true; }
if(l > r){ return false; }
if(lhs_iter == lhs_first){
break;
}
--lhs_iter, --rhs_iter;
}
return false;
}
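The same comparison can be written compactly by first skipping high-order zero limbs on both sides. The sketch below is an illustration under the assumption of little-endian 16-bit limbs; `magnitude_less` is a hypothetical free function and not the article's member function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using data_type = std::vector<std::uint16_t>;

// Returns true when the value of lhs is smaller than the value of [first, last).
// Both sides are little-endian; high-order zero limbs are ignored.
bool magnitude_less(const data_type &lhs, data_type::const_iterator first, data_type::const_iterator last){
    std::size_t ln = lhs.size(), rn = static_cast<std::size_t>(last - first);
    while(ln > 0 && lhs[ln - 1] == 0){ --ln; }
    while(rn > 0 && first[rn - 1] == 0){ --rn; }
    if(ln != rn){ return ln < rn; }
    // Same number of significant limbs: compare from the most significant one down.
    for(std::size_t i = ln; i-- > 0; ){
        if(lhs[i] != first[i]){ return lhs[i] < first[i]; }
    }
    return false;
}
```

Counting significant limbs up front also covers empty operands, which this sketch treats as the value zero.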
Subtraction kar_sub ( range based )
Range-based subtraction.
void kar_sub(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last){
if(kar_less(rhs_first, rhs_last)){
integer s(*this);
data.assign(rhs_first, rhs_last);
unsigned_sub(s);
sign *= - 1;
normalize_data_size();
}else{
std::size_t n = std::distance(rhs_first, rhs_last);
if(n > data.size()){
unsigned_sub(rhs_first, rhs_first + data.size());
}else{
unsigned_sub(rhs_first, rhs_last);
}
normalize_data_size();
}
}
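The idea behind kar_sub, subtracting the smaller magnitude from the larger one and recording the sign separately, can be sketched on plain vectors. This is an illustration under assumptions, not the article's code: `signed_diff`, `less_trimmed` and `trim` are names chosen here, and the operands are copied for simplicity.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <utility>
#include <vector>

using data_type = std::vector<std::uint16_t>;

// Drops high-order zero limbs.
static void trim(data_type &v){
    while(!v.empty() && v.back() == 0){ v.pop_back(); }
}

// True when a < b; both must already be trimmed.
static bool less_trimmed(const data_type &a, const data_type &b){
    if(a.size() != b.size()){ return a.size() < b.size(); }
    for(std::size_t i = a.size(); i-- > 0; ){
        if(a[i] != b[i]){ return a[i] < b[i]; }
    }
    return false;
}

// Returns (sign, |a - b|) with sign = -1 when a < b and +1 otherwise.
std::pair<int, data_type> signed_diff(data_type a, data_type b){
    trim(a);
    trim(b);
    int sign = +1;
    if(less_trimmed(a, b)){
        a.swap(b);          // subtract the smaller magnitude from the larger one
        sign = -1;
    }
    std::uint32_t borrow = 0;
    for(std::size_t i = 0; i < a.size(); ++i){
        std::uint32_t sub = (i < b.size() ? b[i] : 0u) + borrow;
        borrow = a[i] < sub ? 1u : 0u;
        a[i] = static_cast<std::uint16_t>(a[i] + 0x10000u - sub);
    }
    assert(borrow == 0);
    trim(a);
    return std::make_pair(sign, a);
}
```

Because the larger magnitude is always the minuend, the final borrow is zero and the limb loop never has to handle a negative result.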
Addition kar_add ( range based )
Nothing here needs special care.
void kar_add(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last){
std::size_t rhs_size = std::distance(rhs_first, rhs_last);
if(data.size() < rhs_size){
data.resize(rhs_size);
}
typename data_type::iterator operand_iter = data.begin();
DoubleUInt c = 0;
typename data_type::const_iterator iter = rhs_first, end = rhs_last;
--end;
for(; iter != end; ++iter, ++operand_iter){
UInt &operand(*operand_iter);
DoubleUInt s = operand + *iter + c;
operand = static_cast<UInt>(s & base_mask);
c = s >> BitNum;
}
{
UInt &operand(*operand_iter);
DoubleUInt s = static_cast<DoubleUInt>(operand) + static_cast<DoubleUInt>(*iter) + c;
operand = static_cast<UInt>(s & base_mask);
c = s >> BitNum;
if(c > 0){
data.push_back(static_cast<UInt>(c & base_mask));
}
}
}
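A compact version of range-based addition is sketched below, assuming little-endian 16-bit limbs. `add_range` is a hypothetical free function; unlike the listing above, it keeps propagating the carry through the higher limbs of the accumulator when the accumulator is longer than the added range.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

using data_type = std::vector<std::uint16_t>;

// acc += [first, last), both little-endian in base 2^16.
void add_range(data_type &acc, data_type::const_iterator first, data_type::const_iterator last){
    std::size_t n = static_cast<std::size_t>(last - first);
    if(acc.size() < n){ acc.resize(n); }
    std::uint32_t carry = 0;
    // Walk the added limbs, then continue only while a carry is pending.
    for(std::size_t i = 0; i < acc.size() && (i < n || carry > 0); ++i){
        std::uint32_t s = std::uint32_t(acc[i]) + (i < n ? first[i] : 0u) + carry;
        acc[i] = static_cast<std::uint16_t>(s & 0xFFFFu);
        carry = s >> 16;
    }
    if(carry > 0){ acc.push_back(static_cast<std::uint16_t>(carry)); }
}
```

An empty range leaves the accumulator unchanged, so no separate check for that case is needed.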
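Primitive multiplication ( range based )
Multiplies values that have become too small for Karatsuba multiplication to split further. This is again ordinary multiplication that merely happens to be range based.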
void kar_mul_range(const kar_range &lhs, const kar_range &rhs){
data.resize(lhs.size + rhs.size - 1);
if(data.empty()){
normalize_data_size();
return;
}
std::size_t n = 0;
for(typename data_type::const_iterator rhs_iter = rhs.first; rhs_iter != rhs.last; ++rhs_iter, ++n){
UInt rhs_value = *rhs_iter, lhs_value;
std::size_t m = 0;
for(typename data_type::const_iterator lhs_iter = lhs.first; lhs_iter != lhs.last; ++lhs_iter, ++m){
lhs_value = *lhs_iter;
DoubleUInt a = static_cast<DoubleUInt>(rhs_value) * static_cast<DoubleUInt>(lhs_value);
unsigned_single_add(static_cast<UInt>(a & base_mask), n + m);
UInt temp = static_cast<UInt>((a >> BitNum) & base_mask);
if(temp > 0){
unsigned_single_add(temp, n + m + 1);
}
}
}
}
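Implementation
The main work consists of constructing the ranges and calling the kar_* functions written out in the Preparation section. The function calls itself three times, which is the basis for the $\text{O}(n^{\log_2 3})$ bound.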
void kar_rec_mul(const kar_range &x, const kar_range &y){
std::size_t n = ceil_pow2_single((std::max)(x.size, y.size));
if(n < 2){
if(x.size > 0 && y.size > 0){
kar_mul_range(x, y);
normalize_data_size();
}
return;
}
n /= 2;
std::size_t xn = x.size < n ? x.size : n, yn = y.size < n ? y.size : n;
kar_range x0, y0, x1, y1;
x0.first = x.first, x0.last = x.first + xn, x0.size = xn;
x1.first = x.first + xn, x1.last = x.last, x1.size = x.size - xn;
y0.first = y.first, y0.last = y.first + yn, y0.size = yn;
y1.first = y.first + yn, y1.last = y.last, y1.size = y.size - yn;
integer z2, z0;
{
integer tx, ty;
if(x1.first != x1.last){
tx.data.assign(x1.first, x1.last);
tx.kar_sub(x0.first, x0.last);
}else{
tx.data.assign(x0.first, x0.last);
tx.sign = - 1;
}
if(y1.first != y1.last){
ty.data.assign(y1.first, y1.last);
ty.kar_sub(y0.first, y0.last);
}else{
ty.data.assign(y0.first, y0.last);
ty.sign = - 1;
}
kar_range rx, ry;
rx.first = tx.data.begin(), rx.last = tx.data.end(), rx.size = tx.data.size();
ry.first = ty.data.begin(), ry.last = ty.data.end(), ry.size = ty.data.size();
kar_rec_mul(rx, ry);
sign = tx.sign != ty.sign ? +1 : - 1;
}
z0.kar_rec_mul(x0, y0);
add(z0);
if(x.size >= n && y.size >= n){
z2.kar_rec_mul(x1, y1);
add(z2);
}
elemental_shift(n);
add(z0);
if(x.size >= n && y.size >= n){
add(z2, n * 2);
}
normalize_data_size();
}
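Note that the listing does not form $(A + B)(C + D)$. It builds tx = x1 - x0 and ty = y1 - y0 and uses the subtractive form of the middle term, which is why the sign of the first recursive product is flipped before z0 and z2 are added. A short derivation follows, where $R$ ( the radix ) and $z_1$ are symbols introduced here for illustration and $n$ is the split position after `n /= 2`.

```latex
\begin{aligned}
x &= x_1 R^{n} + x_0, \qquad y = y_1 R^{n} + y_0, \qquad z_2 = x_1 y_1, \qquad z_0 = x_0 y_0,\\
(x_1 - x_0)(y_1 - y_0) &= z_2 + z_0 - (x_1 y_0 + x_0 y_1),\\
z_1 = x_1 y_0 + x_0 y_1 &= z_0 + z_2 - (x_1 - x_0)(y_1 - y_0),\\
x y &= z_2 R^{2n} + z_1 R^{n} + z_0 .
\end{aligned}
```

As a decimal example with $R = 10$, $n = 1$, $x = 47$ and $y = 68$: $z_0 = 7 \cdot 8 = 56$, $z_2 = 4 \cdot 6 = 24$, $(4 - 7)(6 - 8) = 6$, so $z_1 = 56 + 24 - 6 = 74$ and $x y = 2400 + 740 + 56 = 3196$.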
Full listing
#ifndef MULTI_PRECISION_INCLUDE_INTEGER_HPP
#define MULTI_PRECISION_INCLUDE_INTEGER_HPP
#include <algorithm>
#include <vector>
#include <iostream>
#include <string>
#include <iterator>
#include <utility>
#include <cstdint>
#include <climits>
#include <cmath>
#include <cstdlib>
namespace multi_precision{
template<class UInt = std::uint16_t, class DoubleUInt = std::uint32_t, class DoubleInt = std::int32_t, DoubleUInt BitNum = sizeof(UInt) * CHAR_BIT>
class integer{
public:
static const DoubleUInt base_mask = (static_cast<DoubleUInt>(1) << BitNum) - 1;
static const UInt half = static_cast<UInt>(static_cast<DoubleUInt>(1) << (BitNum - 1));
using data_type = std::vector<UInt>;
data_type data;
DoubleInt sign = +1;
integer() = default;
integer(const integer&) = default;
integer(integer &&other) : data(std::move(other.data)), sign(other.sign){}
integer(int x){
if(x == 0){ return; }
data.resize(1);
data[0] = std::abs(x);
sign = x >= 0 ? +1 : - 1;
}
integer(UInt x){
if(x == 0){ return; }
data.resize(1);
data[0] = x;
sign = +1;
}
integer(const char *str){
build_from_str(str);
}
integer(const std::string &str){
build_from_str(str.c_str()); // the parser stops at '\0', so pass a null-terminated pointer
}
template<class StrIter>
void build_from_str(StrIter iter){
if(*iter == '-'){
sign = - 1;
++iter;
}
char buff[2] = { 0 };
while(*iter){
buff[0] = *iter;
int a = std::atoi(buff);
*this *= 10;
unsigned_single_add(a);
++iter;
}
}
~integer() = default;
integer &operator =(const integer &other){
data = other.data;
sign = other.sign;
return *this;
}
integer &operator =(integer &&other){
data = std::move(other.data);
sign = std::move(other.sign);
return *this;
}
bool operator <(const integer &other) const{
if(sign < other.sign){ return true; }
if(sign > other.sign){ return false; }
if(data.size() < other.data.size()){ return sign > 0; }
if(data.size() > other.data.size()){ return sign < 0; }
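if(data.empty()){ return false; }
std::size_t i = data.size() - 1;
do{
// The first differing limb from the top decides; the order flips for negative values.
if(data[i] < other.data[i]){ return sign > 0; }
if(data[i] > other.data[i]){ return sign < 0; }
} while(i-- > 0);
return false;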
}
bool operator >(const integer &other) const{
return other < *this;
}
bool operator <=(const integer &other) const{
if(sign < other.sign){ return true; }
if(sign > other.sign){ return false; }
if(data.size() < other.data.size()){ return sign > 0; }
if(data.size() > other.data.size()){ return sign < 0; }
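if(data.empty()){ return true; }
std::size_t i = data.size() - 1;
do{
// The first differing limb from the top decides; the order flips for negative values.
if(data[i] < other.data[i]){ return sign > 0; }
if(data[i] > other.data[i]){ return sign < 0; }
} while(i-- > 0);
return true;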
}
bool operator >=(const integer &other) const{
return other <= *this;
}
bool operator ==(const integer &other) const{
return sign == other.sign && data == other.data;
}
bool operator !=(const integer &other) const{
return !(*this == other);
}
static bool unsigned_less(const data_type &lhs, const data_type &rhs){
if(lhs.size() < rhs.size()){ return true; }
if(lhs.size() > rhs.size()){ return false; }
for(std::size_t i = lhs.size() - 1; i + 1 > 0; --i){
if(lhs[i] < rhs[i]){ return true; }
if(lhs[i] > rhs[i]){ return false; }
}
return false;
}
static bool unsigned_less_eq(const data_type &lhs, const data_type &rhs){
if(lhs.size() < rhs.size()){ return true; }
if(lhs.size() > rhs.size()){ return false; }
for(std::size_t i = lhs.size() - 1; i + 1 > 0; --i){
if(lhs[i] < rhs[i]){ return true; }
if(lhs[i] > rhs[i]){ return false; }
}
return true;
}
void add(const integer &other, std::size_t n = 0){
if(sign == other.sign){
unsigned_add(other, n);
}else{
if(unsigned_less(data, other.data)){
integer temp(*this);
data = other.data;
unsigned_sub(temp);
sign = other.sign;
}else{
unsigned_sub(other);
}
}
normalize_data_size();
}
void sub(const integer &other, std::size_t n = 0){
if(sign != other.sign){
unsigned_add(other, n);
}else{
if(unsigned_less(data, other.data)){
integer temp(*this);
data = other.data;
unsigned_sub(temp);
sign = other.sign;
}else{
unsigned_sub(other);
}
}
normalize_data_size();
}
void unsigned_add(const integer &other, std::size_t n = 0){
if(data.size() < other.data.size() + n){
data.resize(other.data.size() + n);
}
DoubleUInt c = 0;
std::size_t i;
for(i = 0; i < other.data.size(); ++i){
DoubleUInt v = static_cast<DoubleUInt>(data[i + n]) + static_cast<DoubleUInt>(other.data[i]) + c;
data[i + n] = static_cast<UInt>(v & base_mask);
c = v >> BitNum;
}
for(; c > 0; ++i){
if(i + n >= data.size()){ data.resize(i + n + 1); }
DoubleUInt v = static_cast<DoubleUInt>(data[i + n]) + c;
data[i + n] = static_cast<UInt>(v & base_mask);
c = v >> BitNum;
}
}
void unsigned_single_add(UInt c, std::size_t i = 0){
if(i >= data.size()){ data.resize(i + 1); }
for(; c > 0; ++i){
if(i >= data.size()){ data.resize(i + 1); } // grow when the carry runs past the top limb
DoubleUInt v = static_cast<DoubleUInt>(data[i]) + c;
data[i] = static_cast<UInt>(v & base_mask);
c = v >> BitNum;
}
}
void unsigned_sub(const integer &other, std::size_t i = 0){
unsigned_sub(other.data.begin(), other.data.end(), i);
}
void unsigned_sub(const typename data_type::const_iterator &other_first, const typename data_type::const_iterator &other_last, std::size_t i = 0){
typename data_type::const_iterator iter = other_first;
UInt c = 0;
for(std::size_t n = std::distance(other_first, other_last); i < n; ++i, ++iter){
const UInt &other_value(*iter);
UInt t = data[i] - (other_value + c);
if(data[i] < other_value + c){
c = 1;
}else{
c = 0;
}
data[i] = t;
}
for(; c > 0; ++i){
UInt t = data[i] - c;
if(data[i] < c){
c = 1;
}else{
c = 0;
}
data[i] = t;
}
normalize_data_size();
}
static void unsigned_mul(integer &r, const integer &lhs, const integer &rhs){
std::size_t s = lhs.data.size() + rhs.data.size() - 1;
r.data.resize(s + 1);
for(std::size_t i = 0; i < lhs.data.size(); ++i){
for(std::size_t j = 0; j < rhs.data.size(); ++j){
DoubleUInt c = static_cast<DoubleUInt>(lhs.data[i]) * static_cast<DoubleUInt>(rhs.data[j]);
for(std::size_t k = 0; i + j + k < s + 1; ++k){
std::size_t u = i + j + k;
DoubleUInt v = static_cast<DoubleUInt>(r.data[u]) + c;
UInt a = static_cast<UInt>(v & base_mask);
r.data[u] = a;
c = v >> BitNum;
}
}
}
r.normalize_data_size();
}
static integer mul(const integer &lhs, const integer &rhs){
integer r;
r.kar_mul(lhs, rhs);
return r;
}
static integer classic_mul(const integer &lhs, const integer &rhs){
integer r;
unsigned_mul(r, lhs, rhs);
return r;
}
static void unsigned_single_mul(integer &r, const integer &lhs, UInt rhs){
std::size_t s = lhs.data.size() + 1;
r.data.resize(s + 1);
for(std::size_t i = 0; i < lhs.data.size(); ++i){
DoubleUInt c = static_cast<DoubleUInt>(lhs.data[i]) * static_cast<DoubleUInt>(rhs);
for(std::size_t k = 0; i + k < s + 1; ++k){
std::size_t u = i + k;
DoubleUInt v = static_cast<DoubleUInt>(r.data[u]) + c;
UInt a = static_cast<UInt>(v & base_mask);
r.data[u] = a;
c = v >> BitNum;
}
}
r.normalize_data_size();
}
struct kar_range{
typename data_type::const_iterator first, last;
std::size_t size;
};
void kar_mul(const integer &lhs, const integer &rhs){
kar_range tx, ty;
tx.first = lhs.data.begin(), tx.last = lhs.data.end(), tx.size = lhs.data.size();
ty.first = rhs.data.begin(), ty.last = rhs.data.end(), ty.size = rhs.data.size();
kar_rec_mul(tx, ty);
sign = lhs.sign * rhs.sign;
}
void kar_mul_range(const kar_range &lhs, const kar_range &rhs){
data.resize(lhs.size + rhs.size - 1);
if(data.empty()){
normalize_data_size();
return;
}
std::size_t n = 0;
for(typename data_type::const_iterator rhs_iter = rhs.first; rhs_iter != rhs.last; ++rhs_iter, ++n){
UInt rhs_value = *rhs_iter, lhs_value;
std::size_t m = 0;
for(typename data_type::const_iterator lhs_iter = lhs.first; lhs_iter != lhs.last; ++lhs_iter, ++m){
lhs_value = *lhs_iter;
DoubleUInt a = static_cast<DoubleUInt>(rhs_value) * static_cast<DoubleUInt>(lhs_value);
unsigned_single_add(static_cast<UInt>(a & base_mask), n + m);
UInt temp = static_cast<UInt>((a >> BitNum) & base_mask);
if(temp > 0){
unsigned_single_add(temp, n + m + 1);
}
}
}
}
bool kar_less(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last) const{
std::size_t lhs_n = data.size(), rhs_n = std::distance(rhs_first, rhs_last);
typename data_type::const_iterator lhs_first = data.begin(), lhs_iter = lhs_first, rhs_iter = rhs_first;
if(lhs_n < rhs_n){
std::size_t n = rhs_n - lhs_n;
for(std::size_t i = 0; i < n; ++i){
if(*(rhs_iter + (rhs_n - i - 1)) > 0){
return true;
}
}
rhs_iter += lhs_n - 1;
lhs_iter = lhs_first + (lhs_n - 1);
}else if(lhs_n > rhs_n){
std::size_t n = lhs_n - rhs_n;
for(std::size_t i = 0; i < n; ++i){
if(*(lhs_iter + (lhs_n - i - 1)) > 0){
return false;
}
}
// Position both iterators on the highest limb the two operands share.
lhs_iter += rhs_n - 1;
rhs_iter = rhs_iter + (rhs_n - 1);
}else{
lhs_iter += lhs_n - 1;
rhs_iter += rhs_n - 1;
}
for(; ; ){
UInt l = *lhs_iter, r = *rhs_iter;
if(l < r){ return true; }
if(l > r){ return false; }
if(lhs_iter == lhs_first){
break;
}
--lhs_iter, --rhs_iter;
}
return false;
}
void kar_add(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last){
std::size_t rhs_size = std::distance(rhs_first, rhs_last);
if(data.size() < rhs_size){
data.resize(rhs_size);
}
typename data_type::iterator operand_iter = data.begin();
DoubleUInt c = 0;
typename data_type::const_iterator iter = rhs_first, end = rhs_last;
--end;
for(; iter != end; ++iter, ++operand_iter){
UInt &operand(*operand_iter);
DoubleUInt s = operand + *iter + c;
operand = static_cast<UInt>(s & base_mask);
c = s >> BitNum;
}
{
UInt &operand(*operand_iter);
DoubleUInt s = static_cast<DoubleUInt>(operand) + static_cast<DoubleUInt>(*iter) + c;
operand = static_cast<UInt>(s & base_mask);
c = s >> BitNum;
if(c > 0){
data.push_back(static_cast<UInt>(c & base_mask));
}
}
}
void kar_sub(typename data_type::const_iterator &rhs_first, typename data_type::const_iterator &rhs_last){
if(kar_less(rhs_first, rhs_last)){
integer s(*this);
data.assign(rhs_first, rhs_last);
unsigned_sub(s);
sign *= - 1;
normalize_data_size();
}else{
std::size_t n = std::distance(rhs_first, rhs_last);
if(n > data.size()){
unsigned_sub(rhs_first, rhs_first + data.size());
}else{
unsigned_sub(rhs_first, rhs_last);
}
normalize_data_size();
}
}
void kar_rec_mul(const kar_range &x, const kar_range &y){
std::size_t n = ceil_pow2_single((std::max)(x.size, y.size));
if(n < 2){
if(x.size > 0 && y.size > 0){
kar_mul_range(x, y);
normalize_data_size();
}
return;
}
n /= 2;
std::size_t xn = x.size < n ? x.size : n, yn = y.size < n ? y.size : n;
kar_range x0, y0, x1, y1;
x0.first = x.first, x0.last = x.first + xn, x0.size = xn;
x1.first = x.first + xn, x1.last = x.last, x1.size = x.size - xn;
y0.first = y.first, y0.last = y.first + yn, y0.size = yn;
y1.first = y.first + yn, y1.last = y.last, y1.size = y.size - yn;
integer z2, z0;
{
integer tx, ty;
if(x1.first != x1.last){
tx.data.assign(x1.first, x1.last);
tx.kar_sub(x0.first, x0.last);
}else{
tx.data.assign(x0.first, x0.last);
tx.sign = - 1;
}
if(y1.first != y1.last){
ty.data.assign(y1.first, y1.last);
ty.kar_sub(y0.first, y0.last);
}else{
ty.data.assign(y0.first, y0.last);
ty.sign = - 1;
}
kar_range rx, ry;
rx.first = tx.data.begin(), rx.last = tx.data.end(), rx.size = tx.data.size();
ry.first = ty.data.begin(), ry.last = ty.data.end(), ry.size = ty.data.size();
kar_rec_mul(rx, ry);
sign = tx.sign != ty.sign ? +1 : - 1;
}
z0.kar_rec_mul(x0, y0);
add(z0);
if(x.size >= n && y.size >= n){
z2.kar_rec_mul(x1, y1);
add(z2);
}
elemental_shift(n);
add(z0);
if(x.size >= n && y.size >= n){
add(z2, n * 2);
}
normalize_data_size();
}
struct quo_rem{
integer quo, rem;
quo_rem() = default;
quo_rem(const quo_rem&) = default;
quo_rem(quo_rem &&other) : quo(std::move(other.quo)), rem(std::move(other.rem)){}
~quo_rem() = default;
};
static quo_rem unsigned_div(const integer &lhs, const integer &rhs){
quo_rem qr;
qr.quo.data.reserve(lhs.data.size());
qr.rem.data.reserve(rhs.data.size());
qr.rem.data.push_back(0);
for(std::size_t i = lhs.data.size() * static_cast<std::size_t>(BitNum) - 1; i + 1 > 0; --i){
if(!qr.rem.data.empty()){
qr.rem <<= 1;
}else{
qr.rem.data.push_back(0);
}
qr.rem.data[0] |= (lhs.data[i / static_cast<std::size_t>(BitNum)] >> (i % static_cast<std::size_t>(BitNum))) & 1;
if(unsigned_less_eq(rhs.data, qr.rem.data)){
qr.rem.unsigned_sub(rhs);
std::size_t t = i / static_cast<std::size_t>(BitNum);
if(qr.quo.data.size() < t + 1){ qr.quo.data.resize(t + 1); }
qr.quo.data[t] |= 1 << (i % BitNum);
}
}
return qr;
}
static quo_rem div(const integer &lhs, const integer &rhs){
quo_rem qr = unsigned_div(lhs, rhs);
qr.quo.sign = qr.rem.sign = lhs.sign == rhs.sign ? +1 : - 1;
return qr;
}
static UInt ceil_pow2_single(UInt x){
std::size_t n = 0, m = 0;
while(x > 0){
if((x & 1) > 0){
++n;
}
++m;
x >>= 1;
}
if(n == 1){
return static_cast<UInt>(1 << (m - 1));
}
return static_cast<UInt>(1 << m);
}
void elemental_shift(std::size_t n){
data.insert(data.begin(), n, 0);
}
std::size_t bit_num() const{
return bit_num(data.back());
}
static std::size_t bit_num(UInt x){
std::size_t n = 0;
while(x > 0){
++n;
x >>= 1;
}
return n;
}
std::size_t finite_bit_shift_lsr(std::size_t n){
std::size_t
digit = n / static_cast<std::size_t>(BitNum),
shift = n % static_cast<std::size_t>(BitNum),
rev_shift = static_cast<std::size_t>(BitNum) - shift;
UInt c = 0;
for(std::size_t i = 0; i < data.size(); ++i){
UInt x = data[i];
data[i] = (x << shift) | c;
if(rev_shift < BitNum){
c = x >> rev_shift;
}else{
c = 0;
}
}
if(c > 0){ data.push_back(c); }
return digit;
}
void bit_shift_lsr(std::size_t n){
elemental_shift(finite_bit_shift_lsr(n));
}
void bit_shift_rsl(std::size_t n){
std::size_t
digit = n / static_cast<std::size_t>(BitNum),
shift = n % static_cast<std::size_t>(BitNum),
rev_shift = BitNum - shift,
size = data.size();
if(digit > 0){
for(std::size_t i = 0, length = size - digit; i < length; ++i){
data[i] = data[i + digit];
}
for(std::size_t i = 0; i < digit; ++i){
data[size - i - 1] = 0;
}
}
UInt c = 0;
for(std::size_t i = 0; i < data.size(); ++i){
std::size_t j = size - i - 1;
UInt x = data[j];
data[j] = (x >> shift) | c;
if(rev_shift < BitNum){
c = static_cast<UInt>(x << rev_shift); // bits shifted out become the carry into the next lower limb
}else{
c = 0;
}
}
normalize_data_size();
}
void normalize_data_size(){
if(data.empty()){
sign = +1;
return;
}
std::size_t i = data.size() - 1;
for(; ; ){
if(i == 0 && data.front() == 0){
data.clear();
sign = +1;
return;
}
if(data[i] == 0){
--i;
}else{
break;
}
}
data.resize(i + 1);
}
#define MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(assign_op, op, type) \
integer &operator assign_op(const type &rhs){ *this = *this op integer(rhs); return *this; }
MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(+= , +, int);
MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(-= , -, int);
MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(*= , *, int);
MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(/= , /, int);
MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(%= , %, int);
MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(+= , +, unsigned int);
MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(-= , -, unsigned int);
MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(*= , *, unsigned int);
MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(/= , /, unsigned int);
MULTI_PRECISION_ASSIGN_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(%= , %, unsigned int);
};
#define MULTI_PRECISION_COMPARE_OPERATOR(op, type) \
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum> \
bool operator op(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const type &rhs){ \
return lhs op integer<UInt, DoubleUInt, DoubleInt, BitNum>(rhs); \
} \
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum> \
bool operator op(const type &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){ \
return integer<UInt, DoubleUInt, DoubleInt, BitNum>(lhs) op rhs; \
}
MULTI_PRECISION_COMPARE_OPERATOR(==, int);
//MULTI_PRECISION_COMPARE_OPERATOR(!=, int);
MULTI_PRECISION_COMPARE_OPERATOR(<=, int);
MULTI_PRECISION_COMPARE_OPERATOR(>=, int);
MULTI_PRECISION_COMPARE_OPERATOR(==, unsigned int);
//MULTI_PRECISION_COMPARE_OPERATOR(!=, unsigned int);
MULTI_PRECISION_COMPARE_OPERATOR(<=, unsigned int);
MULTI_PRECISION_COMPARE_OPERATOR(>=, unsigned int);
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator +(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
integer<UInt, DoubleUInt, DoubleInt, BitNum> t(lhs);
t.add(rhs);
return t;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator -(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
integer<UInt, DoubleUInt, DoubleInt, BitNum> t(lhs);
t.sub(rhs);
return t;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator *(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
return integer<UInt, DoubleUInt, DoubleInt, BitNum>::mul(lhs, rhs);
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator /(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
return integer<UInt, DoubleUInt, DoubleInt, BitNum>::div(lhs, rhs).quo;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator %(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
return integer<UInt, DoubleUInt, DoubleInt, BitNum>::div(lhs, rhs).rem;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator <<(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, UInt n){
integer<UInt, DoubleUInt, DoubleInt, BitNum> x(lhs);
x.bit_shift_lsr(n);
return x;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator >>(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, UInt n){
integer<UInt, DoubleUInt, DoubleInt, BitNum> x(lhs);
x.bit_shift_rsl(n);
return x;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> &operator +=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
lhs.add(rhs);
return lhs;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> &operator -=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
lhs.sub(rhs);
return lhs;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> &operator *=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
lhs = integer<UInt, DoubleUInt, DoubleInt, BitNum>::mul(lhs, rhs);
return lhs;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> &operator /=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
lhs = integer<UInt, DoubleUInt, DoubleInt, BitNum>::div(lhs, rhs).quo;
return lhs;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator %=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){
lhs = integer<UInt, DoubleUInt, DoubleInt, BitNum>::div(lhs, rhs).rem;
return lhs;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator <<=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, std::size_t n){
integer<UInt, DoubleUInt, DoubleInt, BitNum> x(lhs);
x.bit_shift_lsr(n);
lhs = std::move(x);
return lhs;
}
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator >>=(integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, std::size_t n){
integer<UInt, DoubleUInt, DoubleInt, BitNum> x(lhs);
x.bit_shift_rsl(n);
lhs = std::move(x);
return lhs;
}
#define MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(op, type) \
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum> \
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator op(const integer<UInt, DoubleUInt, DoubleInt, BitNum> &lhs, const type &rhs){ return lhs op integer<UInt, DoubleUInt, DoubleInt, BitNum>(rhs); } \
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum> \
integer<UInt, DoubleUInt, DoubleInt, BitNum> operator op(const type &lhs, const integer<UInt, DoubleUInt, DoubleInt, BitNum> &rhs){ return integer<UInt, DoubleUInt, DoubleInt, BitNum>(lhs) op rhs; }
MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(+, int);
MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(-, int);
MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(*, int);
MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(/ , int);
MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(%, int);
MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(+, unsigned int);
MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(-, unsigned int);
MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(*, unsigned int);
MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(/ , unsigned int);
MULTI_PRECISION_OPERATOR_OVERLOAD_FOR_EXPLICIT_TYPE(%, unsigned int);
template<class UInt, class DoubleUInt, class DoubleInt, DoubleUInt BitNum>
std::ostream &operator <<(
std::ostream &os,
integer<UInt, DoubleUInt, DoubleInt, BitNum> rhs
){
using integer = integer<UInt, DoubleUInt, DoubleInt, BitNum>;
std::vector<std::string> rseq;
integer lo(10);
for(; rhs.data.size() > 0; rhs /= lo){
integer temp = rhs % lo;
if(temp != 0){
rseq.push_back(std::to_string(temp.data[0]));
}else{
rseq.push_back(std::to_string(0));
}
}
os << (rhs.sign > 0 ? "" : "-");
if(!rseq.empty()){
for(auto iter = rseq.rbegin(); iter != rseq.rend(); ++iter){
os << *iter;
}
}else{
os << "0";
}
return os;
}
}
#endif // MULTI_PRECISION_INCLUDE_INTEGER_HPP