LoginSignup
2
1

GNU MPはなぜ速いのか

Last updated at Posted at 2024-06-14

フィボナッチをC++で計算していたとき、

並み居るC++のライブラリより桁違いにGNU MPことGMPが速かったわけですが、GNU MPはなぜ速いのか?を調べてみました。結論から言うと アセンブラ使ってるから なのですが、以降調べたことを書いてみます。

自作の無制限桁整数と比較する

自作fibonacci動かすためだけに特化してコードを小さくした無制限桁数(符号なし)整数

biguint.h

#pragma once

#include <vector>
#include <cinttypes>

template <typename T>
struct biguint {
    std::vector<T> v;
    typedef T value_type;
    biguint(value_type v = 0) : v(1, v) {}
    biguint(const biguint &right) : v(right.v) {}
    biguint<value_type> &operator=(const biguint<value_type> &right) {
        if (this != &right) 
            this->v = right.v;
        return *this;
    }
};

template <typename T>
biguint<T> operator+(const biguint<T>& left, const biguint<T>& right) {
    size_t len = left.v.size() > right.v.size() ? left.v.size() : right.v.size();
    biguint<T> result;
    result.v.resize(len);
    bool carry = false;
    for (size_t i = 0; i < len; ++i) {
        auto r = i < left.v.size() ? left.v[i] : 0;
        r += carry;
        carry = r < carry;
        auto b = i < right.v.size() ? right.v[i] : 0;
        r += b;
        carry = carry || r < b;
        result.v[i] = r;
    }
    if (carry)
        result.v.push_back(1);
    return result;
}

template <typename OUT, typename T>
OUT& operator<<(OUT& out, biguint<T>& val) {
    out << "[";
    bool first = true;        
    for (auto e: val.v)  {
        if (first) first = false;
        else out << ",";
        out << e;
    }
    out << "]";
    return out;
}

test.cpp

#include <iostream>
#include <chrono>
#include <gmpxx.h>
#include "biguint.h"

template <typename T>
T my_fibonacci(int N, T v[] = nullptr) {
    if (N < 2) 
        return N;
    T local_v[] = {0, 1};
    int s;
    if (v) {
        s = N;
    } else {
        v = local_v; 
        s = 2;
    }
    for (int i = s; i <= N; ++i) {
        T v2 = v[0] + v[1];
        v[0] = v[1];
        v[1] = v2;
    }
    return v[1];
}

using namespace std;
using namespace std::chrono;

template<typename T, typename FUNC, typename RESULT, typename... ARGS>
T measure(FUNC f, RESULT* pr, ARGS... args) {
    auto s = high_resolution_clock::now();
    auto tmp = f(args...);
    auto e = high_resolution_clock::now();
    if (pr)
        *pr = tmp;
    return duration_cast<T>(e -s);
}

template<typename FUNC>
void test(const char* name, FUNC f) {
    const int U = 100000;
    for (int i = U; i < U * 10 + 1; i += U) {
        decltype(f(0, nullptr)) r;
        auto t = measure<nanoseconds>(f, &r, i, nullptr);
        std::cout << name << "(" << i << ")" /*<< "->" << r*/ << ",time[sec]: " << t.count() / 1000000000. << std::endl;
    }
}

int main() {
     test("biguint", &my_fibonacci<biguint<uint64_t>>);
     test("GMP", &my_fibonacci<mpz_class>);
    return 0;
}

コンパイルして…

g++ -O3 test.cpp -lgmpxx -lgmp -o test

実行した結果がコレです。

biguint(100000),time[sec]: 0.143848
biguint(200000),time[sec]: 0.575212
biguint(300000),time[sec]: 1.23255
biguint(400000),time[sec]: 2.19872
biguint(500000),time[sec]: 3.50379
biguint(600000),time[sec]: 5.1232
biguint(700000),time[sec]: 6.91124
biguint(800000),time[sec]: 9.02167
biguint(900000),time[sec]: 11.3196
biguint(1000000),time[sec]: 13.9295
GMP(100000),time[sec]: 0.0784611
GMP(200000),time[sec]: 0.315758
GMP(300000),time[sec]: 0.67119
GMP(400000),time[sec]: 1.21538
GMP(500000),time[sec]: 1.8671
GMP(600000),time[sec]: 2.70328
GMP(700000),time[sec]: 3.62083
GMP(800000),time[sec]: 4.78603
GMP(900000),time[sec]: 5.91166
GMP(1000000),time[sec]: 7.57548

こんな決め打ちで何もしてないコードと比較しても2倍近く速いわけです。C++だけで速くなるのか?と疑問が湧きます。

GMPのコードを調べる

入手と展開

$ wget 'https://gmplib.org/download/gmp/gmp-6.3.0.tar.xz'
$ tar xvf gmp-6.3.0.tar.xz
$ cd gmp-6.3.0

基本方針

fibonacciで時間がかかるのは、Nが大きくなって膨大な桁になった数値のコピー加算です。そんなことは見なくても分かるでしょう。コピーについてはどうせ最後はmemcpyなので差は出ないとしても、加算は工夫の余地があります。なので「加算のコード」を調べるのが基本方針になります。

includeしているgmpxx.hでmpz_classを探す

これはすぐ見つかります。

// gmpxx.h:1756
typedef __gmp_expr<mpz_t, mpz_t> mpz_class;

__gmp_expr<mpz_t, mpz_t>を探す

__gmp_exprはテンプレート引数にバリエーションがありますが、実際に使用されてるのはこれです。

// gmpxx.h:1572
template <>
class __gmp_expr<mpz_t, mpz_t>
{

見れば分かりますがやはり2項演算の加算はクラス内にはありません。

__gmp_expr<mpz_t, mpz_t>operator+を探す

探すとoperator+は2つしかなく二項なのはコレのようです。

// gmpxx.h:3318
__GMP_DEFINE_BINARY_FUNCTION(operator+, __gmp_binary_plus)

__GMP_DEFINE_BINARY_FUNCTIONマクロを探す

マクロになると普通に追えなくなるので大変です。

// gmpxx.h:3022
#define __GMP_DEFINE_BINARY_FUNCTION(fun, eval_fun) \
__GMPP_DEFINE_BINARY_FUNCTION(fun, eval_fun)        \
__GMPN_DEFINE_BINARY_FUNCTION(fun, eval_fun)

それぞれ展開して一致することを願います。

__GMP[PN]_DEFINE_BINARY_FUNCTIONを探す

// gmpxx.h:2965
#define __GMPP_DEFINE_BINARY_FUNCTION(fun, eval_fun)                   \
                                                                       \
template <class T, class U, class V, class W>                          \
inline __gmp_expr<typename __gmp_resolve_expr<T, V>::value_type,       \
__gmp_binary_expr<__gmp_expr<T, U>, __gmp_expr<V, W>, eval_fun> >      \
fun(const __gmp_expr<T, U> &expr1, const __gmp_expr<V, W> &expr2)      \
{                                                                      \
  return __gmp_expr<typename __gmp_resolve_expr<T, V>::value_type,     \
     __gmp_binary_expr<__gmp_expr<T, U>, __gmp_expr<V, W>, eval_fun> > \
    (expr1, expr2);                                                    \
}

// gmpxx.h:3009
#define __GMPN_DEFINE_BINARY_FUNCTION(fun, eval_fun)              \
__GMPNS_DEFINE_BINARY_FUNCTION(fun, eval_fun, signed char)        \
__GMPNU_DEFINE_BINARY_FUNCTION(fun, eval_fun, unsigned char)      \
__GMPNS_DEFINE_BINARY_FUNCTION(fun, eval_fun, signed int)         \
__GMPNU_DEFINE_BINARY_FUNCTION(fun, eval_fun, unsigned int)       \
__GMPNS_DEFINE_BINARY_FUNCTION(fun, eval_fun, signed short int)   \
__GMPNU_DEFINE_BINARY_FUNCTION(fun, eval_fun, unsigned short int) \
__GMPNS_DEFINE_BINARY_FUNCTION(fun, eval_fun, signed long int)    \
__GMPNU_DEFINE_BINARY_FUNCTION(fun, eval_fun, unsigned long int)  \
__GMPND_DEFINE_BINARY_FUNCTION(fun, eval_fun, float)              \
__GMPND_DEFINE_BINARY_FUNCTION(fun, eval_fun, double)             \
/* __GMPNLD_DEFINE_BINARY_FUNCTION(fun, eval_fun, long double) */

どうやら該当するのはPの方で、T,U,V,Wが全てmpz_tとしたときのようです。つまりこう。

__gmp_expr<typename __gmp_resolve_expr<mpz_t, mpz_t>::value_type,       
__gmp_binary_expr<__gmp_expr<mpz_t, mpz_t>, __gmp_expr<mpz_t, mpz_t>, __gmp_binary_plus> >      
operator+(const __gmp_expr<mpz_t, mpz_t> &expr1, const __gmp_expr<mpz_t, mpz_t> &expr2)

戻り値がえらいややこしくなってて、素直に__gmp_exprを返すのではなく、二項演算を含む式として返されて、それに対してoerator=などがexpr()することで__gmp_exprが返る仕組みになっている。

その二項演算を含む式を表す__gmp_exprを探す

// gmpxx.h:2700
template <class T, class U, class V, class Op>
class __gmp_expr
<T, __gmp_binary_expr<__gmp_expr<T, U>, __gmp_expr<T, V>, Op> >
{
private:
  typedef __gmp_expr<T, U> val1_type;
  typedef __gmp_expr<T, V> val2_type;

  __gmp_binary_expr<val1_type, val2_type, Op> expr;
public:
  __gmp_expr(const val1_type &val1, const val2_type &val2)
    : expr(val1, val2) { }
  void eval(typename __gmp_resolve_expr<T>::ptr_type p) const
  {
    __gmp_temp<T> temp2(expr.val2, p);
    expr.val1.eval(p);
    Op::eval(p, p, temp2.__get_mp());
  }

T,U,Vはmpz_tでOpは__gmp_binary_plus。__gmp_binary_exprは実際にはval1,val2の保管先でしかない。
後は代入演算子のeval呼び出しを確認すると…

// gmpxx.h:1678
  template <class T, class U>
  __gmp_expr<value_type, value_type> & operator=(const __gmp_expr<T, U> &expr)
  { __gmp_set_expr(mp, expr); return *this; }
// gmpxx.h:2208
template <class T>
inline void __gmp_set_expr(mpz_ptr z, const __gmp_expr<mpz_t, T> &expr)
{
  expr.eval(z);
}

__gmp_binary_plusのevalを探す

// gmpxx.h:183
struct __gmp_binary_plus
{
  static void eval(mpz_ptr z, mpz_srcptr w, mpz_srcptr v)
  { mpz_add(z, w, v); }

無事CのGMP関数に辿り着いた。

mpz_addを探す

// gmp.h:632
#define mpz_add __gmpz_add
__GMP_DECLSPEC void mpz_add (mpz_ptr, mpz_srcptr, mpz_srcptr);

なんかaliasまで設定されている。

// mpz/add.c:32
#define OPERATION_add
#include "aors.h"

// mpz/aors.h:35
#ifdef OPERATION_add
#define FUNCTION     mpz_add
#define VARIATION
#endif

// mpz/aors.h:49
void
FUNCTION (mpz_ptr w, mpz_srcptr u, mpz_srcptr v)
{
...
      mp_limb_t cy_limb = mpn_add (wp, up, abs_usize, vp, abs_vsize);

mpn_addを探す

// mpn/add.c:31
#define __GMP_FORCE_mpn_add 1

#include "gmp-impl.h"

// gmp.h:1470
#define mpn_add __MPN(add)
#if __GMP_INLINE_PROTOTYPES || defined (__GMP_FORCE_mpn_add)
__GMP_DECLSPEC mp_limb_t mpn_add (mp_ptr, mp_srcptr, mp_size_t, mp_srcptr, mp_size_t);
#endif

// gmp.h:249
#ifndef __MPN
#define __MPN(x) __gmpn_##x
#endif

// gmp.h:2141
#if defined (__GMP_EXTERN_INLINE) || defined (__GMP_FORCE_mpn_add)
...
mp_limb_t
mpn_add (mp_ptr __gmp_wp, mp_srcptr __gmp_xp, mp_size_t __gmp_xsize, mp_srcptr __gmp_yp, mp_size_t __gmp_ysize)
{
  mp_limb_t  __gmp_c;
  __GMPN_ADD (__gmp_c, __gmp_wp, __gmp_xp, __gmp_xsize, __gmp_yp, __gmp_ysize);
  return __gmp_c;
}

__GMPN_ADDを探す

// gmp.h:1961
#define __GMPN_ADD(cout, wp, xp, xsize, yp, ysize)              \
  __GMPN_AORS (cout, wp, xp, xsize, yp, ysize, mpn_add_n,       \
               (((wp)[__gmp_i++] = (__gmp_x + 1) & GMP_NUMB_MASK) == 0))

__GMPN_AORSを探す

// gmp.h:1927
#define __GMPN_AORS(cout, wp, xp, xsize, yp, ysize, FUNCTION, TEST)     \
  do {                                                                  \
    mp_size_t  __gmp_i;                                                 \
    mp_limb_t  __gmp_x;                                                 \
                                                                        \
    /* ASSERT ((ysize) >= 0); */                                        \
    /* ASSERT ((xsize) >= (ysize)); */                                  \
    /* ASSERT (MPN_SAME_OR_SEPARATE2_P (wp, xsize, xp, xsize)); */      \
    /* ASSERT (MPN_SAME_OR_SEPARATE2_P (wp, xsize, yp, ysize)); */      \
                                                                        \
    __gmp_i = (ysize);                                                  \
    if (__gmp_i != 0)                                                   \
      {                                                                 \
        if (FUNCTION (wp, xp, yp, __gmp_i))                             \
          {                                                             \
            do                                                          \
              {                                                         \
                if (__gmp_i >= (xsize))                                 \
                  {                                                     \
                    (cout) = 1;                                         \
                    goto __gmp_done;                                    \
                  }                                                     \
                __gmp_x = (xp)[__gmp_i];                                \
              }                                                         \
            while (TEST);                                               \
          }                                                             \
      }                                                                 \
    if ((wp) != (xp))                                                   \
      __GMPN_COPY_REST (wp, xp, xsize, __gmp_i);                        \
    (cout) = 0;                                                         \
  __gmp_done:                                                           \
    ;                                                                   \
  } while (0)

ここでFUNCTIONは mpn_add_n

mpn_add_nを探す

// gmp.h:1480
#define mpn_add_n __MPN(add_n)
__GMP_DECLSPEC mp_limb_t mpn_add_n (mp_ptr, mp_srcptr, mp_srcptr, mp_size_t);
dnl  mpn/add_n.asm:65
ifdef(`OPERATION_add_n', `
        define(ADCSBB,        adc)
        define(func,          mpn_add_n)
        define(func_nc,       mpn_add_nc)')
ifdef(`OPERATION_sub_n', `
        define(ADCSBB,        sbb)
        define(func,          mpn_sub_n)
        define(func_nc,       mpn_sub_nc)')

MULFUNC_PROLOGUE(mpn_add_n mpn_add_nc mpn_sub_n mpn_sub_nc)

ABI_SUPPORT(DOS64)
ABI_SUPPORT(STD64)

ASM_START()
        TEXT
        ALIGN(16)
PROLOGUE(func_nc)
        FUNC_ENTRY(4)
IFDOS(` mov     56(%rsp), %r8   ')
        mov     R32(n), R32(%rax)
        shr     $2, n
        and     $3, R32(%rax)
        bt      $0, %r8                 C cy flag <- carry parameter
        jrcxz   L(lt4)

        mov     (up), %r8
        mov     8(up), %r9
        dec     n
        jmp     L(mid)

EPILOGUE()
        ALIGN(16)
PROLOGUE(func)
        FUNC_ENTRY(4)
        mov     R32(n), R32(%rax)
        shr     $2, n
        and     $3, R32(%rax)
        jrcxz   L(lt4)

        mov     (up), %r8
        mov     8(up), %r9
        dec     n
        jmp     L(mid)

L(lt4): dec     R32(%rax)
        mov     (up), %r8
        jnz     L(2)
        ADCSBB  (vp), %r8
        mov     %r8, (rp)
        adc     R32(%rax), R32(%rax)
        FUNC_EXIT()
        ret

L(2):   dec     R32(%rax)
        mov     8(up), %r9
        jnz     L(3)
        ADCSBB  (vp), %r8
        ADCSBB  8(vp), %r9
        mov     %r8, (rp)
        mov     %r9, 8(rp)
        adc     R32(%rax), R32(%rax)
        FUNC_EXIT()
        ret

L(3):   mov     16(up), %r10
        ADCSBB  (vp), %r8
        ADCSBB  8(vp), %r9
        ADCSBB  16(vp), %r10
        mov     %r8, (rp)
        mov     %r9, 8(rp)
        mov     %r10, 16(rp)
        setc    R8(%rax)
        FUNC_EXIT()
        ret

        ALIGN(16)
L(top): ADCSBB  (vp), %r8
        ADCSBB  8(vp), %r9
        ADCSBB  16(vp), %r10
        ADCSBB  24(vp), %r11
        mov     %r8, (rp)
        lea     32(up), up
        mov     %r9, 8(rp)
        mov     %r10, 16(rp)
        dec     n
        mov     %r11, 24(rp)
        lea     32(vp), vp
        mov     (up), %r8
        mov     8(up), %r9
        lea     32(rp), rp
L(mid): mov     16(up), %r10
        mov     24(up), %r11
        jnz     L(top)

L(end): lea     32(up), up
        ADCSBB  (vp), %r8
        ADCSBB  8(vp), %r9
        ADCSBB  16(vp), %r10
        ADCSBB  24(vp), %r11
        lea     32(vp), vp
        mov     %r8, (rp)
        mov     %r9, 8(rp)
        mov     %r10, 16(rp)
        mov     %r11, 24(rp)
        lea     32(rp), rp

        inc     R32(%rax)
        dec     R32(%rax)
        jnz     L(lt4)
        adc     R32(%rax), R32(%rax)
        FUNC_EXIT()
        ret
EPILOGUE()

mpn_add_nはアセンブラになっており、主要な計算はアセンブラで書かれているということが分かった。

感想

単純長距離作業がメインになるものはやはり手で書くアセンブラが速く、コンパイラの最適化任せにはならないらしい。どおりでC++のヘッダオンリーライブラリがあまりないわけだ…

2
1
8

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1