AVX512での逆数近似と精度補正

  • 2
    いいね
  • 0
    コメント

はじめに

AVX-512にはvrcp28pdという、28bit精度の逆数近似命令がある。AVX-512Fには含まれておらず、AVX-512ER拡張が必要。これが本当に28bitの精度があるか、また一回の補正でフル精度になるか確認してみる。

コードは以下においておく。

https://github.com/kaityo256/rcp28_sample

無補正

まずは無補正で精度の確認。

test.cpp
double out[8];

int
main(){
  std::mt19937 mt(1);
  std::uniform_real_distribution<double> ud(0.0,1.0);
  for(int i=0;i<10;i++){
    double a = ud(mt);
    __m512d z = _mm512_set1_pd(a);
    __m512d zinv = _mm512_rcp28_pd(z);
    double ainv = 1.0/a;
    _mm512_storeu_pd(out, zinv);
    bitdump(ainv);
    bitdump(out[0]);
    printf("%d\n",bitcomp(ainv,out[0]));
    printf("\n");
  }
}

一様乱数を10個作って、それを普通に逆数にした値と、逆数近似した場合の値を調べる。

bitdumpはdoubleのビット表現をダンプする関数で、中身はこんな感じ。

void
bitdump(double a){
  printf("%.16f\n",a);
  char *x = (char*)(&a);
  int count = 0;
  for(int i=0;i<8;i++){
    for(int j=0;j<8;j++){
      if(x[7-i] & (1<<(7-j)))printf("1");
      else printf("0");
      count++;
      if(count==1 || count==12)printf(" ");
    }
  }
  printf("\n");
}

bitcompは、2つのdoubleを受け取って、仮数部が何ビット一致するか調べる関数1。実際には符号ビットとか指数部とかも一緒に比較して(それらは一致すると仮定して)、最後に12ビット分引いてる。

int
bitcomp(double a, double b){
  char *x = (char*)(&a);
  char *y = (char*)(&b);
  int sum = 0;
  for(int i=0;i<8;i++){
    for(int j=0;j<8;j++){
      char xb = x[7-i] & (1<<(7-j));
      char yb = y[7-i] & (1<<(7-j));
      if(xb!=yb)return sum-12;
      sum++;
    }
  }
  return sum - 12;
}

実行結果はこんな感じ。

1.0028231394486751
0 01111111111 0000000010111001000001000110101110011010101001101100
1.0028231400601602
0 01111111111 0000000010111001000001000110111000111010111111001000
29
(snip)
1.1815988407141602
0 01111111111 0010111001111101010000101111100111011011110111000110
1.1815988415114642
0 01111111111 0010111001111101010000101111110101001000100000011000
29

これは、フル精度でやったら「1.0028231394486751」となるべき値が、vrcp28pd使ったら「1.0028231400601602」になって、仮数部は29ビット一致している、という意味。

たまに20bitの一致とかでることがある。

1.1346335406111951
0 01111111111 0010001001110111010101111111110111011110101010010101
1.1346335413761803
0 01111111111 0010001001110111010110000000000100100111110001011000
0.881342
20

これはたまたま21〜28ビットが1で、29bit目に誤差で1ビット足されてばたばたとビットが変わってしまったからで、ちゃんと28bitの精度が出ている。そういう意味では、ちゃんと引き算して誤差を見ないとだめなんだけど、まぁ気にしないことにする。

補正公式の導出

28bitの逆数が得られたら、一回補正すると精度が倍になってほぼフル精度になるはずである。それを調べる。

ある数$a$の逆数近似の値$x$が得られたとする。本来ある数とその逆数をかけたら1になるはずだが、これは近似された値なので誤差があるだろう。それを$\epsilon$とすると、

$$
a x = 1 + \epsilon
$$

となる。これを変形して$1-\epsilon$を作ろう。

$$
2 - a x = 1 - \epsilon
$$

先の式と辺々かけると、

$$
a x (2 - a x) = 1 - \epsilon^2
$$

これは、$x(2 - a x)$が$O(\epsilon^2)$の精度の逆数になったことを意味する。

次に、これがニュートン法になっていることを示す。

ニュートン法は、ある関数$f(x)=0$となる点$x$を探す問題で、近似値$x_n$が与えられた時に、次の値$x_{n+1}$を

$$
x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}
$$

として求める手法である。いま、$a$の逆数が知りたいのであるから、

$$
f(x) = \frac{1}{ax} - 1
$$

と定義すれば、$f(x) = 0$となる$x$が求めたい$a$の逆数となる2

代入すると、

$$
x_{n+1} = x_n(2 - a x_n)
$$

となり、先程の補正公式と一致する。

逆数近似+補正

というわけで逆数近似に一度補正をかけてみる。

test2.cpp
int
main(){
  std::mt19937 mt(1);
  __m512d v2 = _mm512_set1_pd(2.0);
  std::uniform_real_distribution<double> ud(0.0,1.0);
  for(int i=0;i<10;i++){
    double a = ud(mt);
    __m512d z = _mm512_set1_pd(a);
    __m512d zinv = _mm512_rcp28_pd(z);
    __m512d ze = _mm512_mul_pd(z,zinv);
    __m512d zinv2 = _mm512_sub_pd(v2,ze);
    zinv2 = _mm512_mul_pd(zinv2,zinv);
    double ainv = 1.0/a;
    _mm512_storeu_pd(out, zinv2);
    bitdump(ainv);
    bitdump(out[0]);
    printf("%d\n",bitcomp(ainv,out[0]));
    printf("\n");
  }
}

zinvが逆数近似の結果、zinv2が一度補正をかけたもの。実行結果はこんな感じ。

1.0028231394486751
0 01111111111 0000000010111001000001000110101110011010101001101100
1.0028231394486751
0 01111111111 0000000010111001000001000110101110011010101001101100
52
(snip)
1.1815988407141602
0 01111111111 0010111001111101010000101111100111011011110111000110
1.1815988407141602
0 01111111111 0010111001111101010000101111100111011011110111000110
52

というわけで、ちゃんとフル精度出てるみたいですね。

まとめ

AVX-512ERで追加されたvrcp28pd命令と、その精度補正を試してみた。フル精度の除算が欲しい場合、近似+補正公式にバラした方が早くなるかどうかはコードによると思うが(まだ試してない)、ループの形によっては結構効くかも?


  1. ビット演算とか使うともっときれいにできると思うけど手抜き。 

  2. ここで$f(x) = ax -1$としても$f(x) =0$の解が求める逆数となるが、これは線形関数なのでニュートン法がうまく働かず、自明な結果しか出てこない。