51
34

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

八戸高専Advent Calendar 2023

Day 9

平方根の高速化とマジックナンバー

Posted at

平方根の高速化とマジックナンバー

 
この記事はアドカレに参加しています。

平方根

 $\sqrt{a}$という形が平方根ですね。$a^{\frac{1}{2}}$と書くこともできます。この平方根、グラフィック系のプログラムなんかでめちゃくちゃ見かけますし、それだけ使います。AIの分野でも頻繁に使用するみたいです。

数回程度sqrt関数を呼び出す程度ならあまり気になりませんが、何千回、何万回、何億回もsqrt関数を呼び出すと、その計算速度の遅さに辟易します。

そこで、多少の精度を捨ててでもsqrt関数を高速化しようとしたのがfast sqrt algorithmと呼ばれるアルゴリズムです。この記事はそんなアルゴリズムの紹介です。

fast sqrt algorithmは精度面でも速度面でも最良ではないです。基本的に低レベルな命令セットの方が精度を保ちつつも高速な場合が多いです。(僕の環境では、c++ではstd::sqrtよりfast sqrt algorithmの方が高速に動作しました。wgslではデフォルトで用意されている関数の方が速かったです。)

fast sqrt algorithmの方針

 $\sqrt{a}$という値は以下のように変形することができます。

\begin{align}
\sqrt{a}&=a^{\frac{1}{2}} \\
&=\frac{a^{\frac{1}{2}} \times a^{\frac{1}{2}}}{a^{\frac{1}{2}}} \\
&=\frac{a}{a^{\frac{1}{2}}} \\
&=a \times \frac{1}{\sqrt{a}} \\
\end{align}

$\frac{1}{\sqrt{a}}$という逆平方根を計算した後に、$a$を掛けることで$\sqrt{a}$を求めることができます。つまり、$\frac{1}{\sqrt{a}}$を高速に計算できれば$\sqrt{a}$も高速に計算できることになります。

 逆平方根を高速に求めるプログラムは以下のようなとても短いものです。

//https://en.wikipedia.org/wiki/Fast_inverse_square_root
float q_rsqrt(float number)
{
  long i;
  float x2, y;
  const float threehalfs = 1.5F;

  x2 = number * 0.5F;
  y  = number;
  i  = * ( long * ) &y;                       // evil floating point bit level hacking
  i  = 0x5f3759df - ( i >> 1 );               // what the fuck?
  y  = * ( float * ) &i;
  y  = y * ( threehalfs - ( x2 * y * y ) );   // 1st iteration
  // y  = y * ( threehalfs - ( x2 * y * y ) );   // 2nd iteration, this can be removed

  return y;
}

return y*number;のように書き換えると、平方根になります。

逆平方根$\frac{1}{\sqrt{a}}$はニュートン法で近似しており、たったの二回の近似だけでもかなりの精度です。誤差が気にならない場合は一回の近似だけでも充分です。

ニュートン法

 ニュートン法の導出(?)です。

ニュートン法では関数$f(x)=y$があったとき、$f(x_{0})=0$となるような$x_{0}$を求めます。

まずは$f(x)=y$を微分の形で表わします。

\begin{align}
f'(x_{k})&=y' \\
f'(x_{k})&=\frac{\Delta y}{\Delta x}
\end{align}

このとき、2つの座標$(x_{k},f(x_{k})),(x_{k+1},f(x_{k+1}))$があるとして、$\Delta x=x_{k}-x_{k+1}$とおくと、

\begin{align}
f'(x_{k})&=\frac{f(x_{k})-f(x_{k+1})}{x_{k}-x_{k+1}}
\end{align}

のようになります。$x_{k+1}$が$x_{0}$により近い値であるとすると、$f(x_{k+1})\fallingdotseq 0$となるので、

\begin{align}
f'(x_{k})&=\frac{f(x_{k})-0}{x_{k}-x_{k+1}} \\
f'(x_{k})&=\frac{f(x_{k})}{x_{k}-x_{k+1}} \\
x_{k}-x_{k+1}&=\frac{f(x_{k})}{f'(x_{k})} \\
x_{k+1}&=x_{k}-\frac{f(x_{k})}{f'(x_{k})}
\end{align}

これで、任意の値$x_{k}$から、$f(x_{0})=0$となるような$x_{0}$により近い値$x_{k+1}$を求める式が分かりました。

ニュートン法では以下のように何度も近似を繰り返すことで、$f(x_{0})=0$となるような$x_{0}$を求めます。

\begin{align}
x_{k+1}&=x_{k}-\frac{f(x_{k})}{f'(x_{k})} \\
x_{k+2}&=x_{k+1}-\frac{f(x_{k+1})}{f'(x_{k+1})} \\
x_{k+3}&=x_{k+2}-\frac{f(x_{k+2})}{f'(x_{k+2})}
\end{align}

ニュートン法で解く逆平方根

 ニュートン法で逆平方根を近似することを考えてみましょう。

$a$の逆平方根$\frac{1}{\sqrt{a}}$の値を$x$として、$f(x)=0$の形を目指します。

\begin{align}
\frac{1}{\sqrt{a}}&=x \\
\frac{1}{a^{\frac{1}{2}}}&=x \\
\Big( \frac{1}{a^{\frac{1}{2}}}\Big) ^2&=x^2 \\
\frac{1}{a}&=x^2 \\
\frac{1}{x^2}&=a \\
x^{-2}-a&=0 \\
\end{align}

$f(x)=x^{-2}-a$として、ニュートン法の式にあてはめます。

\begin{align}
x_{k+1}&=x_{k}-\frac{f(x_{k})}{f'(x_{k})} \\
x_{k+1}&=x_{k}-\frac{x_{k}^{-2}-a}{-2x_{k}^{-3}} \\
x_{k+1}&=x_{k}-(-0.5x_{k}+0.5ax_{k}^3) \\
x_{k+1}&=x_{k}+0.5x_{k}-0.5ax_{k}^3 \\
x_{k+1}&=1.5x_{k}-0.5ax_{k}^3 \\
x_{k+1}&=x_{k}(1.5-0.5ax_{k}^2)
\end{align}

$\frac{1}{\sqrt{a}}$の値を近似していく式を求めることができました。fast sqrt algorithmではこのニュートン法の式を一回(又は二回)だけ行うことで、計算量を抑えています。

初期値

 $\frac{1}{\sqrt{a}}=x$となるような$x$の近似値はニュートン法で求めることができます。

\begin{align}
x_{k+1}&=x_{k}(1.5-0.5ax_{k}^2)
\end{align}

このとき、初期値(最初の$x_{k}$)はできるだけ$x$の値に近いほうが、ニュートン法での収束が速いです。

fast sqrt algorithmでは、この初期値を以下のようにして求めています。

  i  = 0x5f3759df - ( *(long*)&number >> 1 );

ビットシフトと減算だけです。かっこいいですね。なんの脈絡もなく0x5f3759dfという値が使用されていることから、0x5f3759dfはマジックナンバーと呼ばれたりしています。

このマジックナンバーはIEEE 754という規格にそって求められた値です。

IEEE 754 float型は32bitの浮動小数点で、1bitの符号ビット、8bitの指数ビット、23bitの仮数ビットからなります。

値(X) 符号(F) 指数(E) 仮数(T)
0.5 0 01111110 00000000000000000000000
2 0 10000000 00000000000000000000000
1809 0 10001001 11000100010000000000000
125.125 0 10000101 11110100100000000000000

表にあるように$X,F,E,T$に置き換えると、数式では以下のように書くことができます。

X=
\begin{cases}
+(2^{E-127} \times 1.T) & (F=0) \\
-(2^{E-127} \times 1.T) & (F=1) \\
\end{cases}

$\frac{1}{\sqrt{a}}=x$となるような$x$の値をこの規格のビット列から考えてみます。$a=2^{E-127} \times 1.T$として、

\begin{align}
a&=2^{E-127} \times 1.T \\
a^{-\frac{1}{2}}&=2^{-\frac{E-127}{2}} \times 1.T^{-\frac{1}{2}} & ...(1) \\
a^{-\frac{1}{2}}&=2^{E'-127} \times 1.T' & ...(2) \\
\end{align}

$E'$と$T'$の値が分かれば、$\frac{1}{\sqrt{a}}=x$となるような$x$の値をビット列から求めることができます。

$E'$を求めるには、(1)式と(2)式の2にある指数部分で方程式を作ります。

\begin{align}
E'-127&=-\frac{E-127}{2} \\
E'&=-0.5E+63.5+127 \\
E'&=-0.5E+190.5
\end{align}

0.5の倍数は1bit右シフトなので、$\frac{1}{\sqrt{a}}=x$の近似値を$a$から求めるには以下のようになります。(0x5f0000は符号部0と指数部190から。)

  x  = (0x5f0000 | (T')) - ( *(long*)&a >> 1 );

これで指数部分の近似値を求めることができました。

仮数部分である$T'$は、力技で一番誤差の小さい値を探します。マジックナンバーの完成まであともう少しです。

マジックナンバーを求める

 仮数部分$T'$は0x0000000xffffffまでの0x1000000通り考えることができます。0x1000000通りの$T'$の中で、一番誤差の小さい値を探します。

誤差は、$a$に対する正確な平方根sqrt(a)と$T'$を用いたfsqrt(a)との差の大きさです。符号部分と指数部分を固定した0x3f0000000x3fffffff$(0.5 \leqq a < 2.0)$までの値の中での最大誤差で比較します。

\min_{T'}\max_{a}|sqrt(a) - fsqrt(a)|

WebGPUで求めてみました。

main.html
main.html
<!--

main.html

-->

<!DOCTYPE html>
<html lang = "ja">
  <head>
    <title>Sparkle</title>
    <meta charset="UTF-8">

    <script src="webgpuControl.js"></script>

  </head>
  <body bgcolor="#ffffff" text="000000">

    <h3>
      debug
    </h3>

	<!-- dom操作用の箱 -->
	<div id="domBox">
	</div>

	<!-- スクリプトの実行 -->
	<script src="main.js"></script>

  </body>
</html>
main.js
main.js
/*

main.js

*/

console.log("test")

let Sr = 0
init()

async function init() {
  let num = 0
  let Fr = 1000

  let s = 0
  let f = 0xffffff

  for(let j = 0; j < 6; j++) {
    console.log("\n\n\nsf : " + s.toString(16) + " : " + f.toString(16))
    for(let i = s; i <= f; i += (0x100000 >> (j*4))) {
      console.log(i.toString(16))
      await run(0x5F000000 | i)
      if(Fr > Sr) {
        Fr = Sr
        num = i
      }
    }

    console.log("best num : " + num.toString(16) + "\n error : " + Fr.toString())
    s = num - (0x100000 >> (j*4))
    f = num + (0x100000 >> (j*4))
  }
  console.log("\n\nfinish")
  console.log("best num : " + num.toString(16) + "\n error : " + Fr.toString())

  console.log("magic number")
  await run(0x5F3759df)
  await run(0x5F375A86)
}

async function run(bnum) {
  //let stime = performance.now()

  let ren = new webgpuControl("test")
  await ren.setup()
  ren.setCode(`

  struct Output {
    data: array<f32>,
  };

  @group(0) @binding(0)
  var<storage, read_write> output : Output;

  fn sqrt_M(x: f32) -> f32 {
    var hx: f32 = 0.5 * x;
    var tmp: i32 = ` + (0x5f000000 | bnum.toString()) + `- (bitcast<i32>(x) >> 1);
    var xk: f32 = bitcast<f32>(tmp);

    xk = xk * (1.5 - (hx * xk * xk));
    //xk = xk * (1.5 - (hx * xk * xk));
    return xk * x;
  }

  @compute @workgroup_size(256,1,1)
  fn main(@builtin(global_invocation_id) global_id : vec3<u32>,
          @builtin(num_workgroups) numWork : vec3<u32>) {
    //
    var index: u32 = global_id.x + global_id.y * numWork.x * 256;
    var bitx: i32 = bitcast<i32>(0x3f000000 | index);
    var x: f32 = bitcast<f32>(bitx);
    output.data[index] = abs(sqrt(x) - sqrt_M(x));

    if(global_id.x == (0xffffff - 1)) {
      var index2: u32 = index + 1;
      var bitx2: i32 = bitcast<i32>(0x3f000000 | index);
      var x2: f32 = bitcast<f32>(bitx2);
      var t = abs(sqrt(x2) - sqrt_M(x2));
      if(t > output.data[index]){
        output.data[index] = t;
      }
    }
  }
  `)

  let buffsize = Math.pow(2,16)
  ren.setBuf(0, 256*4*buffsize, "SRC")
  ren.run(buffsize/2, 2, 1, "main")
  let obj = await ren.getBuf(0, 256*4*buffsize, 0)

  //console.log(performance.now() - stime)

  let ary = new Float32Array(obj[0])
  //console.log(ary.byteLength)
  //console.log(JSON.parse(JSON.stringify(ary)))
  let max = ary.reduce((a,b)=>{return Math.max(a,b)})

  console.log(max)

  ren.unmap(obj)
  ren.delete_device()

  Sr = max
}
webgpuControl.js
webgpuControl.js
/*

webgpuControl.js
  WebGPUのComputerShaderを使いやすくしたラッパークラス。
  exanple : setup -> setCode -> setBuf -> run -> getBuf -> unmap

*/

class webgpuControl {
  constructor(label = "") {
    this.label = label
    this.bufList = []
    this.error = undefined
  }

  async setup() {
    if (!navigator.gpu) {
      this.error = "WebGPU not supported."
      return 0
    }
    else {
      this.adapter = await navigator.gpu.requestAdapter()
      if (!this.adapter) {
        this.error = "Couldn't request WebGPU adapter."
        return 0
      }
      this.device = await this.adapter.requestDevice({label : this.label})
    }
    return 0
  }

  setCode(code = " ") {
    if(this.error === undefined) {
      this.computeShaderModule = this.device.createShaderModule({code: code, label : (this.label + " : setCode")})
    }
  }

  setBuf(bindingIndex = 0, byteSize = 0, bufTypebuf = undefined, buf = undefined) {
    if(this.error === undefined) {
      let flag;
      if(bufTypebuf === "DST") {
        flag = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_DST
      }
      else if(bufTypebuf === "SRC") {
        flag = GPUBufferUsage.STORAGE | GPUBufferUsage.COPY_SRC
      }
      else if(bufTypebuf === "UNIFORM") {
        flag = GPUBufferUsage.UNIFORM | GPUBufferUsage.COPY_DST
      }
      else {//"NON"
        flag = GPUBufferUsage.STORAGE
      }
      this.bufList[bindingIndex] = {
        binding : bindingIndex,
        resource : {
          buffer : this.device.createBuffer({
            size : byteSize, usage : flag,
            label : (this.label + " : setBuf(" + bindingIndex + ")")
          })
        }
      }

      if(buf !== undefined) {
        this.device.queue.writeBuffer(this.bufList[bindingIndex].resource.buffer, 0, buf.data, buf.offset, buf.size)
      }
    }
  }

  run(xLen = 1, yLen = 1, zLen = 1, entryPoint = "main") {
    if(this.error === undefined) {
      let computePipeline = this.device.createComputePipeline({
        layout : "auto",
        compute : {
          module : this.computeShaderModule,
          entryPoint : entryPoint
        },
        label : (this.label + " : run_createComputePipeline")
      })

      let bindGroup = this.device.createBindGroup({
        layout : computePipeline.getBindGroupLayout(0),
        entries: this.bufList,
        label : (this.label + " : run_createBindGroup")
      })
      
      let commandEncoder = this.device.createCommandEncoder({label : (this.label + " : run_createCommandEncoder")})
      let passEncoder = commandEncoder.beginComputePass({label : (this.label + " : run_beginComputePass")})
      passEncoder.setPipeline(computePipeline)
      passEncoder.setBindGroup(0, bindGroup)
      passEncoder.dispatchWorkgroups(xLen, yLen, zLen)
      passEncoder.end()
      
      this.device.queue.submit([commandEncoder.finish()])
    }
  }

  async getBuf(bindingIndex = 0, size = 0, offset = 0) {
    if(this.error === undefined) {
      let stagingBuffer = this.device.createBuffer({
        mappedAtCreation : false,
        size : size,
        usage : GPUBufferUsage.COPY_DST | GPUBufferUsage.MAP_READ,
        label : (this.label + " : getBuf(" + bindingIndex + ")")
      })
      
      let copyEncoder = this.device.createCommandEncoder({label : (this.label + " : getBuf_createCommandEncoder")})
      copyEncoder.copyBufferToBuffer(
        this.bufList[bindingIndex].resource.buffer, offset,
        stagingBuffer, 0, size
      )

      this.device.queue.submit([copyEncoder.finish()])
      await stagingBuffer.mapAsync(GPUMapMode.READ)
      let copyArrayBuffer = stagingBuffer.getMappedRange()
      
      return [copyArrayBuffer, stagingBuffer]
    }
  }

  unmap(obj = undefined) {
    if(this.error === undefined) {
      obj[1].unmap()
    }
  }

  debug() {
    return this
  }

  delete_device() {
    this.device.destroy()
  }
}

結果として、0x5f3700a0がマジックナンバーとなりました。

念のため、c++でも誤差値が正しく求められているかを確認してみました。

main.cpp
main.cpp

#include <iostream>
#include <cmath>
#include <cstdint>
#include <bit>
#include "windows.h"

float t_sqrtF(float a) {
    float hx = 0.5 * a;
    int tmp = 0x5f3700a0 - (std::bit_cast<int>(a) >> 1);
    float xk = std::bit_cast<float>(tmp);

    xk = xk * (1.5 - (hx * xk * xk));
    //xk = xk * (1.5 - (hx * xk * xk));
    return xk * a;
}

int main() {
    LARGE_INTEGER freq;//速度計測用
    QueryPerformanceFrequency(&freq);
    LARGE_INTEGER start, end;
    QueryPerformanceCounter(&start);

    float s = 0;
    uint32_t num = 0;
    for (uint32_t l = 0; l < 0x1000000; l++) {
        float t = std::sqrt(std::bit_cast<float>(0x3f000000 | l)) - t_sqrtF(std::bit_cast<float>(0x3f000000 | l));
        t = std::abs(t);
        if (l == 0)std::cout << t << std::endl;
        if (s < t) {
            s = t;
            num = l;
        }
    }

    QueryPerformanceCounter(&end);//時間計測
    double time = static_cast<double>(end.QuadPart - start.QuadPart) * 1000.0 / freq.QuadPart;

    std::cout << s << std::endl;
    std::cout << num << std::endl;


    system("pause");

    return 0;
}

なぜ0x5F3759DFとならなかったのかを調べてみると、どうやら人によって誤差の基準が違うようです。他の方が見つけたマジックナンバーも気になる方はぜひ参考文献を覗いてみてください。

プログラム例

 fast sqrt algorithmのプログラム例です。マジックナンバーの部分はお好きなものでどうぞ。

#include <bit>

float t_sqrtF(float a) {
    float hx = 0.5 * a;
    int tmp = 0x5f3700a0 - (std::bit_cast<int>(a) >> 1);
    float xk = std::bit_cast<float>(tmp);

    xk = xk * (1.5 - (hx * xk * xk));
    //xk = xk * (1.5 - (hx * xk * xk));
    return xk * a;
}

↓こちらはWGSLで書かれたものです。

  fn sqrt_M(x: f32) -> f32 {
    var hx: f32 = 0.5 * x;
    var tmp: i32 = 0x5f3700a0 - (bitcast<i32>(x) >> 1);
    var xk: f32 = bitcast<f32>(tmp);

    xk = xk * (1.5 - (hx * xk * xk));
    //xk = xk * (1.5 - (hx * xk * xk));
    return xk * x;
  }

コメントを外すと精度が上がりますが、その際はマジックナンバーの値も変わるので注意です。ニュートン法二回近似した場合のマジックナンバーをWebGPUで求めてみると、0x5f373a00となりました。

参考文献

江添亮のブログ 平方根のアルゴリズム
高速根号計算 (fast sqrt algorithm)
Best Square Root Method - Algorithm - Function (Precision VS Speed)
Fast inverse square root
滴了庵日録 高速逆平方根(fast inverse square root)のアルゴリズム解説

むすび

 調べてみると様々なマジックナンバーがあるようで、とても戸惑いました。本家本元の0x5F3759DFが一番よく見かけますね。平方根の高速アルゴリズムの紹介でした。

51
34
3

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
51
34

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?