はじめに
AVX-512にはvrcp28pd
という、28bit精度の逆数近似命令がある。AVX-512Fには含まれておらず、AVX-512ER拡張が必要。これが本当に28bitの精度があるか、また一回の補正でフル精度になるか確認してみる。
コードは以下においておく。
無補正
まずは無補正で精度の確認。
double out[8];
int
main(){
std::mt19937 mt(1);
std::uniform_real_distribution<double> ud(0.0,1.0);
for(int i=0;i<10;i++){
double a = ud(mt);
__m512d z = _mm512_set1_pd(a);
__m512d zinv = _mm512_rcp28_pd(z);
double ainv = 1.0/a;
_mm512_storeu_pd(out, zinv);
bitdump(ainv);
bitdump(out[0]);
printf("%d\n",bitcomp(ainv,out[0]));
printf("\n");
}
}
一様乱数を10個作って、それを普通に逆数にした値と、逆数近似した場合の値を調べる。
bitdumpはdouble
のビット表現をダンプする関数で、中身はこんな感じ。
void
bitdump(double a){
printf("%.16f\n",a);
char *x = (char*)(&a);
int count = 0;
for(int i=0;i<8;i++){
for(int j=0;j<8;j++){
if(x[7-i] & (1<<(7-j)))printf("1");
else printf("0");
count++;
if(count==1 || count==12)printf(" ");
}
}
printf("\n");
}
bitcompは、2つのdoubleを受け取って、仮数部が何ビット一致するか調べる関数1。実際には符号ビットとか指数部とかも一緒に比較して(それらは一致すると仮定して)、最後に12ビット分引いてる。
int
bitcomp(double a, double b){
char *x = (char*)(&a);
char *y = (char*)(&b);
int sum = 0;
for(int i=0;i<8;i++){
for(int j=0;j<8;j++){
char xb = x[7-i] & (1<<(7-j));
char yb = y[7-i] & (1<<(7-j));
if(xb!=yb)return sum-12;
sum++;
}
}
return sum - 12;
}
実行結果はこんな感じ。
1.0028231394486751
0 01111111111 0000000010111001000001000110101110011010101001101100
1.0028231400601602
0 01111111111 0000000010111001000001000110111000111010111111001000
29
(snip)
1.1815988407141602
0 01111111111 0010111001111101010000101111100111011011110111000110
1.1815988415114642
0 01111111111 0010111001111101010000101111110101001000100000011000
29
これは、フル精度でやったら「1.0028231394486751」となるべき値が、vrcp28pd
使ったら「1.0028231400601602」になって、仮数部は29ビット一致している、という意味。
たまに20bitの一致とかでることがある。
1.1346335406111951
0 01111111111 0010001001110111010101111111110111011110101010010101
1.1346335413761803
0 01111111111 0010001001110111010110000000000100100111110001011000
0.881342
20
これはたまたま21〜28ビットが1で、29bit目に誤差で1ビット足されてばたばたとビットが変わってしまったからで、ちゃんと28bitの精度が出ている。そういう意味では、ちゃんと引き算して誤差を見ないとだめなんだけど、まぁ気にしないことにする。
補正公式の導出
28bitの逆数が得られたら、一回補正すると精度が倍になってほぼフル精度になるはずである。それを調べる。
ある数$a$の逆数近似の値$x$が得られたとする。本来ある数とその逆数をかけたら1になるはずだが、これは近似された値なので誤差があるだろう。それを$\epsilon$とすると、
$$
a x = 1 + \epsilon
$$
となる。これを変形して$1-\epsilon$を作ろう。
$$
2 - a x = 1 - \epsilon
$$
先の式と辺々かけると、
$$
a x (2 - a x) = 1 - \epsilon^2
$$
これは、$x(2 - a x)$が$O(\epsilon^2)$の精度の逆数になったことを意味する。
次に、これがニュートン法になっていることを示す。
ニュートン法は、ある関数$f(x)=0$となる点$x$を探す問題で、近似値$x_n$が与えられた時に、次の値$x_{n+1}$を
$$
x_{n+1} = x_n - \frac{f(x_n)}{f'(x_n)}
$$
として求める手法である。いま、$a$の逆数が知りたいのであるから、
$$
f(x) = \frac{1}{ax} - 1
$$
と定義すれば、$f(x) = 0$となる$x$が求めたい$a$の逆数となる2。
代入すると、
$$
x_{n+1} = x_n(2 - a x_n)
$$
となり、先程の補正公式と一致する。
逆数近似+補正
というわけで逆数近似に一度補正をかけてみる。
int
main(){
std::mt19937 mt(1);
__m512d v2 = _mm512_set1_pd(2.0);
std::uniform_real_distribution<double> ud(0.0,1.0);
for(int i=0;i<10;i++){
double a = ud(mt);
__m512d z = _mm512_set1_pd(a);
__m512d zinv = _mm512_rcp28_pd(z);
__m512d ze = _mm512_mul_pd(z,zinv);
__m512d zinv2 = _mm512_sub_pd(v2,ze);
zinv2 = _mm512_mul_pd(zinv2,zinv);
double ainv = 1.0/a;
_mm512_storeu_pd(out, zinv2);
bitdump(ainv);
bitdump(out[0]);
printf("%d\n",bitcomp(ainv,out[0]));
printf("\n");
}
}
zinv
が逆数近似の結果、zinv2
が一度補正をかけたもの。実行結果はこんな感じ。
1.0028231394486751
0 01111111111 0000000010111001000001000110101110011010101001101100
1.0028231394486751
0 01111111111 0000000010111001000001000110101110011010101001101100
52
(snip)
1.1815988407141602
0 01111111111 0010111001111101010000101111100111011011110111000110
1.1815988407141602
0 01111111111 0010111001111101010000101111100111011011110111000110
52
というわけで、ちゃんとフル精度出てるみたいですね。
まとめ
AVX-512ERで追加されたvrcp28pd
命令と、その精度補正を試してみた。フル精度の除算が欲しい場合、近似+補正公式にバラした方が早くなるかどうかはコードによると思うが(まだ試してない)、ループの形によっては結構効くかも?