LoginSignup
6
4

More than 5 years have passed since last update.

高次の逆数近似補正

Posted at

はじめに

一般に除算は遅いため、高速に逆数の近似値を求め、必要に応じて精度を補正する、ということが行われる。しかし、AVXの逆数近似が12bit精度、AVX-512の逆数近似は28bit精度なので、AVX2でニュートン法3回反復、AVX-512のやつだと一度の補正でフル精度出る。

「これじゃつまらない!もっと高次の補正をしてみたい!」 という人のためにHPC-ACEという命令セットには8bit逆数近似命令がある。これを試してみよう。

8bitの逆数近似の雰囲気

8bitの逆数近似ってわりとすごい(割り切りが)。おそらくテーブル引いてるかなにかしているのだろう。とりあえずフル精度で計算した逆数と、8bit逆数近似で計算した結果のビットダンプを3組くらい表示してみる。

1.002823139448675
0 01111111111 0000000010111001000001000110101110011010101001101100
1.001953125000000
0 01111111111 0000000010000000000000000000000000000000000000000000

1.072320096785134
0 01111111111 0001001010000011100100011110001010001001001000111110
1.070312500000000
0 01111111111 0001001000000000000000000000000000000000000000000000

7.804911688494701
0 10000000001 1111001110000011101011000101000010010000001011110011
7.796875000000000
0 10000000001 1111001100000000000000000000000000000000000000000000

順番にフル精度の逆数、そのビットダンプ、8bit精度の逆数、そのビットダンプと並んでいる。仮数部に8bitしか値がなく、残りがゼロにされていることがわかる。多分テーブルひいてから正規化するとかそういうことをやってるんだと思う。これを存分に補正しよう。

高次の補正公式

ニュートン法でやっても良いのだが、特に逆数近似の補正はテイラー展開でやるほうがきれいなので、そっちでやる。

ある数$a$に関して、逆数近似$\tilde{a}_1$が得られたとしよう。これを1次の近似と呼ぶ。本来ならばある数とその逆数をかけると1になるはずだが、近似値なので誤差が出るだろう。それを$\epsilon$としよう。

a \tilde{a}_1 = 1 + \epsilon

これをこう変形する。

\frac{1}{a} = \frac{\tilde{a}_1}{1 + \epsilon}

左辺が求めたい値であった。ここで右辺を$\epsilon$に関してテイラー展開する。

\frac{1}{a} = \tilde{a}_1(1 - \epsilon + \epsilon^2 - \epsilon^3 + \cdots)

もともと$\epsilon = a \tilde{a}_1 - 1$であったから、あとは好きな次数で止めて代入すれば高次補正公式を作ることができる。

まず、2次公式は

\tilde{a}_2 = \tilde{a}_1(1 - \epsilon) = \tilde{a}_1 (2 - a \tilde{a}_1)

となり、ニュートン法と一致する。また、ニュートン法二回反復

\tilde{a}_4 =  \tilde{a}_2 (2 - a \tilde{a}_2)

は、テイラー展開の4次補正

\tilde{a}_4 =  \tilde{a}_1(1 - \epsilon + \epsilon^2 - \epsilon^3) 

と一致する。ニュートン反復を使うと、2次、4次、8次補正の系列が得られるが、テイラー展開を使えば任意の次数の補正を得ることができる。

補正結果

8bit精度ということは、2次の補正で16bit、3次の補正で24bit・・・となり、フル精度に到達するには7次が必要となる(多分)。それを確認するため、こんなコードを書いてみた。

test.cpp
#include <emmintrin.h>
#include <random>
#include <stdio.h>

#define EPS (a * ainv[1] - 1.0)
#define EPS2 EPS*EPS
#define EPS3 EPS*EPS*EPS
#define EPS4 EPS2*EPS2
#define EPS5 EPS2*EPS3
#define EPS6 EPS3*EPS3

const char *order[7] =  {"1st", "2nd", "3rd", "4th", "5th", "6th", "7th"};
double out[2];

int
main(){
  std::mt19937 mt(1);
  std::uniform_real_distribution<double> ud(0.0,1.0);
  double ainv[8];
  for(int i=0;i<10;i++){
    double a = ud(mt);
    __m128d z = _mm_set_pd(a,a);
    __m128d zinv = __builtin_fj_rcpa_v2r8(z);
    _mm_store_pd(out,zinv);
     ainv[0] = 1.0/a;
     ainv[1] = out[0];
     ainv[2] = ainv[1] * (1.0 - EPS);
     ainv[3] = ainv[1] * (1.0 - EPS + EPS2);
     ainv[4] = ainv[1] * (1.0 - EPS + EPS2 - EPS3);
     ainv[5] = ainv[1] * (1.0 - EPS + EPS2 - EPS3 + EPS4);
     ainv[6] = ainv[1] * (1.0 - EPS + EPS2 - EPS3 + EPS4 - EPS5);
     ainv[7] = ainv[1] * (1.0 - EPS + EPS2 - EPS3 + EPS4 - EPS5 + EPS6);
    for(int j=0;j<8;j++){
      printf("%.15f", ainv[j]);
      if (j==0) printf(" (full) \n");
      else printf(" %02d (%s)\n", bitdiff(ainv[0],ainv[j]),order[j-1]);
    }
    printf("\n");
  }
}

テイラー展開の様子が見やすいように表記を工夫してみた。累乗の扱いがちょっとアレだけど、まぁそこはそれ。

ちなみに、bit単位で一致精度を見る関数bitdiffは、一度64bit整数に落としてから引き算し、残ったbit数を数えている。こんな感じ。

int
bitdiff(double a1, double a2){
  long *xp1 = (long*)(&a1);
  long *xp2 = (long*)(&a2);
  long x = (*xp1 > *xp2)? *xp1 - *xp2 : *xp2-*xp1;
  for(int i=12;i<64;i++){
    if (x & (long(1) << (63-i)))return i - 12;
  }
  return 52;
}

結果はこうなる。

1.002823139448675 (full) 
1.001953125000000 10 (1st)
1.002822384654424 20 (2nd)
1.002823138793842 30 (3rd)
1.002823139448107 40 (4th)
1.002823139448674 50 (5th)
1.002823139448675 51 (6th)
1.002823139448675 51 (7th)

1.072320096785134 (full) 
1.070312500000000 08 (1st)
1.072316338164118 18 (2nd)
1.072320089748247 27 (3rd)
1.072320096771959 36 (4th)
1.072320096785109 45 (5th)
1.072320096785134 51 (6th)
1.072320096785134 51 (7th)

7.804911688494701 (full) 
7.796875000000000 08 (1st)
7.804903413146492 18 (2nd)
7.804911679973606 28 (3rd)
7.804911688485928 38 (4th)
7.804911688494692 48 (5th)
7.804911688494701 52 (6th)
7.804911688494701 52 (7th)

(snip)

1.181598840714160 (full) 
1.179687500000000 09 (1st)
1.181595748951373 18 (2nd)
1.181598835712960 27 (3rd)
1.181598840706070 36 (4th)
1.181598840714147 46 (5th)
1.181598840714160 52 (6th)
1.181598840714160 52 (7th)

試した範囲では、全てのケースにおいて6次補正でフル精度になった。最低8bitの近似なので実際にはもっと精度が高いのと、最初に8bitしか精度がない場合でも、一度補正したら17bit以上の精度が出たりするので、7次の補正が必要なことはないみたい。

まとめ

HPC-ACEの8bit逆数近似とその補正をやってみた。どうでもいいけど、ビットダンプする時に「あわないなぁ」としばらく悩んでいたのだが、原因はエンディアンの違いだった・・・

参考文献

6
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
4