コンピュータには大昔から組み込まれているそうだが,自分で導出してみる.
2次元のベクトル$(x,y) \in \mathbb R^2$の長さは
$$ d(x,y) = \sqrt{x^2 + y^2} . $$
である.しかしルートの計算は大変なので,どうにかルート以外の計算でこの近似を求められないか考えたい.
次の関数を考える:
$$ \hat d(x, y) = a \min (|x|,|y|) + b \max (|x|,|y|) . $$
ここの$a, b$をうまいこと考えれば$d(x,y)$の良い近似になるだろう.
求解
$\min(x,y)$と$\max(x,y)$の組み合わせは全てで8つあるが,今回は$0 \le y \le x$の領域だけを考える.
この領域を考えるだけで十分である.$x$と$y$の順序や符号を変えれば,全ての$(x, y)$をこの領域に引き込んできて,同じように扱える(例:$d(x,y) = d(-y, x)$).
この領域においては$\min=y, \max=x$なので,
$$ \hat d(x, y) = a y + b x $$
となる.
最適な$a, b$を求めるための最適化問題を次のように定める:
$$
\min_{a,b}
J =
\iint_{0 \le y \le x} dx ~ dy ~
\frac{1}{2} \left[
d(x,y) - \hat d(x, y)
\right] ^2 .
$$
$J$を求めると:
$$
J =
L^4 \left[
\frac{a}{12} - \frac{\sqrt{2} a}{6}
- \frac{\sqrt{2}b}{8} + \frac{a^2}{24}
+ \frac{b^2}{8}
- \frac{b \ln(2\sqrt{2}+1)}{8}
+ \frac{1}{6}
\right] .
$$
ただし$L$は計算のために定めた$x$の上限(後で消えるから大丈夫).
この$a$および$b$微分は:
$$\begin{cases}
\partial_a J =
L^4\left(
a/12 + b/8 - \sqrt{2}/6 + 1/12
\right),\\
\partial_b J =
L^4 \left(
a/8 + b/4 - \log(\sqrt{2} + 1)/8 - \sqrt{2}/8
\right) .
\end{cases}$$
これらの偏微分が0になる時の$a,b$は:
$$\begin{cases}
a = 5\sqrt{2} - 3\ln(\sqrt{2} + 1) - 4 \approx 0.42695 , \\
b = 2\ln(\sqrt{2} + 1) - 2\sqrt{2} + 2 \approx 0.93432
\end{cases}$$
と求まった!
テスト
| $x$ | $y$ | 答え $d(x,y)$ | 近似 $\hat d(x,y)$ | Error(%) |
|---|---|---|---|---|
| 1.0 | 1.0 | 1.414 | 1.361 | 3.744 |
| 2.0 | 1.0 | 2.236 | 2.296 | 2.662 |
| 4.0 | 3.0 | 5.000 | 5.018 | 0.362 |
| 10.0 | 6.0 | 11.662 | 11.905 | 2.084 |
| 0.5 | 0.2 | 0.539 | 0.553 | 2.606 |
| 123.0 | 45.0 | 130.973 | 134.134 | 2.413 |
| 100.0 | 0.0 | 100.000 | 93.432 | 6.568 |
だいたいのケースで誤差率を5%程度までに抑えられていることが分かる.領域の縁近くだとちょっと誤差が大きくなるかも.
今後の展望(?)
他の説明変数を入れてみてもいいかもしれない.例えば$|x| + |y|$など.また,中央寄り用・縁寄り用など,$|x|/|y|$比に応じて$a, b$を切り替えてもいいかもしれない.
コード
MATLABを使用した.
%% 求解
syms a b x y L
actual = sqrt(x^2 + y^2);
approx = a*y + b*x;
J = int(int( ...
0.5 * (actual - approx)^2, ...
y, 0, x), x, 0, L)
dJ_da = simplify(diff(J, a))
dJ_db = simplify(diff(J, b))
sol = solve([dJ_da == 0, dJ_db == 0], [a, b], ReturnConditions=false);
a_sol = sol.a
b_sol = sol.b
a_num = double(vpa(a_sol, 10));
b_num = double(vpa(b_sol, 10));
fprintf('a = %.5f\n', a_num);
fprintf('b = %.5f\n', b_num);
%% テスト
tests = [ ...
1, 1;
2, 1;
4, 3;
10, 6;
0.5, 0.2;
123, 45;
100, 0;
];
for i = 1:size(tests,1)
xt = tests(i,1);
yt = tests(i,2);
gt = sqrt(xt^2 + yt^2);
pred = a_num*yt + b_num*xt;
err = abs(gt - pred)/gt*100;
fprintf('%5.1f | %5.1f | %8.3f | %10.3f | %9.3f\n', xt, yt, gt, pred, err);
end