LoginSignup
11
0

More than 1 year has passed since last update.

Goで平方根を求める、の一歩先

Last updated at Posted at 2022-12-14

この記事は、and factory.inc Advent Calendar 202215日目 の記事です。
昨日は@cpp0302【IDEA】【GoLand】エラーハンドリングのライブテンプレートを作成するでした。

この話は普段の業務とはほぼ関係のないちょっとしたパズルです。リラックスして読んでいただければと思います!

平方根を求めるのは簡単?

突然ですがGoで平方根を求めたいとします。正確には整数値を渡して平方根を超えない最大の整数を返す関数を作りたいとします。どうすれば良いでしょうか?

ちょっと慣れた方であればmathパッケージにあるのではないかと想像できるでしょうし、「Go 平方根」で検索すればmath.Sqrtという関数にはすぐに行き着くのではと思います。math.Sqrtはfloat64を受け取ってfloat64を返す関数なので変換をかける必要がありますが、それを踏まえてこんな感じで関数が作れるかと思います。

package main

import (
	"fmt"
	"math"
)

func main() {
	nums := []int{
		1,
		4,
		9,
		10,
		20,
		100,
	}
	for _, v := range nums {
		fmt.Printf("getSqrt(%d)=>%d \n", v, getSqrt(v))
	}
}

func getSqrt(i int) int {
	return int(math.Sqrt(float64(i)))
}

結果も正しそうです

getSqrt(1)=>1 
getSqrt(4)=>2 
getSqrt(9)=>3 
getSqrt(10)=>3 
getSqrt(20)=>4 
getSqrt(100)=>10 

ここまでは難しくはないかと思います。しかしこのプログラムがいかなる状況でも正しく結果を返すでしょうか?ということを考えると少し難しい問題となります。実際にある値を入れると正しく動かなくなります。さて何を入れれば良いでしょうか?ぜひここでスクロールを止めて考えてみてください。

トラップその1

以下を見ていただければ答えがわかるかと思います

101*101=10201なので10200は100になり、10201は101になります。

getSqrt(10200)=>100
getSqrt(10201)=>101
getSqrt(10202)=>101

同様に100000001*100000001=10000000200000001なので10000000200000000は100000000になり、10000000200000001は100000001になるはずが

getSqrt(10000000200000000)=>100000001
getSqrt(10000000200000001)=>100000001
getSqrt(10000000200000002)=>100000001

そうはなりません

大きい数だと問題が出るんじゃないか、float64の変換が怪しいんじゃないかと思われた方は本質をつかんでいるかと思います。float64での値も合わせてデバッグ出力すると

getSqrt(10000000200000000)=>100000001 10000000200000000.000000
getSqrt(10000000200000001)=>100000001 10000000200000000.000000
getSqrt(10000000200000002)=>100000001 10000000200000002.000000

そもそもfloat64に直した時点で値がズレていることがわかるかと思います。整数を浮動小数点に直した時点で丸められることによる誤差が生じますし、math.Sqrtの結果もfloat64での近似値となるので欲しかった答えを得ることはできません

math.Sqrtを使わずに求めるには?

時間さえかけて良いのであればmath.Sqrtを使わなくても実現できます。愚直に1から順番に調べていけば良いです。

package main

import (
	"fmt"
)

func main() {
	nums := []int{
		1,
		4,
		9,
		10,
		20,
		100,
	}
	for _, v := range nums {
		fmt.Printf("getSqrt(%d)=>%d \n", v, getSqrt(v))
	}
}

func getSqrt(i int) int {
	o := 0
	for j := 1; j <= i; j++ {
		if j*j <= i {
			o = j
		} else {
			break
		}
	}
	return o
}

ただこれだと、数が多いと辛いです。以下のように大きい数を大量に渡すとtimeoutになってしまいました。

より高速に行う方法の一つとして二分探索という手法が利用できます。これは平方根より小さいですか?という質問を繰り返してその結果により答えを半分づつ範囲を絞っていくような動きをしています。このようにすると64bitの整数だと最大64回の試行で答えを得ることができるため、全部調べるよりも圧倒的に早く調べることができます。

package main

import (
	"fmt"
)

func main() {
	nums := []int{
		1,
		4,
		9,
		10,
		20,
		100,
	}
	for _, v := range nums {
		fmt.Printf("getSqrt(%d)=>%d \n", v, getSqrt(v))
	}
}

func getSqrt(i int) int {
	return bs(1, i, func(c int) bool {
		return c*c <= i
	})
}

func bs(ok, ng int, f func(int) bool) int {
	if !f(ok) {
		return -1
	}
	if f(ng) {
		return ng
	}

	for ok-ng > 1 || ok-ng < -1 {
		mid := (ok + ng) / 2

		if f(mid) {
			ok = mid
		} else {
			ng = mid
		}
	}

	return ok
}
getSqrt(1)=>1 
getSqrt(4)=>2 
getSqrt(9)=>3 
getSqrt(10)=>3 
getSqrt(20)=>4 
getSqrt(100)=>10 

ここまでの結果は正しいです。ただ、上記のプログラム残念ながら問題があります。
それはなんでしょうか?またスクロールを止めて考えてみてください。

トラップその2

以下を見ていただければ答えがわかるかと思います

package main

import (
	"fmt"
)

func main() {

	nums := []int{
		3037000498,
		3037000499,
		3037000500,
	}
	for _, v := range nums {
		fmt.Printf("%d \n", v*v)
	}
}
9223372024852248004 
9223372030926249001 
-9223372036709301616 

64bitの範囲であればintのとり得る値は -9223372036854775808 ~ 9223372036854775807 となります。3037000500以上の数をお互い掛け合わせると9223372036854775807を超えるので結果がおかしくなります。またしても大きい数が原因ですが、今回は浮動小数点の誤差が原因ではなくオーバーフローが原因です。

実際に以下のnumsを先程のプログラムに渡してあげると

getSqrt(3037000498)=>55108 
getSqrt(3037000499)=>55108 
getSqrt(3037000500)=>3037000500 

3037000500以上の結果がおかしくなります。

以上を踏まえて、以下のように書けばおそらく64bit環境下では正しい答えが返ってくるかと思います。

package main

import (
	"fmt"
)

func main() {
	nums := []int{
		10200,
		10201,
		10202,
		1002000,
		1002001,
		1002002,
		100020000,
		100020001,
		100020002,
		10000200000,
		10000200001,
		10000200002,
		1000002000000,
		1000002000001,
		1000002000002,
		100000020000000,
		100000020000001,
		100000020000002,
		10000000200000000,
		10000000200000001,
		10000000200000002,
	}
	for _, v := range nums {
		fmt.Printf("getSqrt(%d)=>%d \n", v, getSqrt(v))
	}
}

func getSqrt(i int) int {
	return bs(1, min(i, 3037000499), func(c int) bool {
		return c*c <= i
	})
}

func bs(ok, ng int, f func(int) bool) int {
	if !f(ok) {
		return -1
	}
	if f(ng) {
		return ng
	}

	for ok-ng > 1 || ok-ng < -1 {
		mid := (ok + ng) / 2

		if f(mid) {
			ok = mid
		} else {
			ng = mid
		}
	}

	return ok
}

func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

全て正しい結果を得ることができました!

getSqrt(10200)=>100 
getSqrt(10201)=>101 
getSqrt(10202)=>101 
getSqrt(1002000)=>1000 
getSqrt(1002001)=>1001 
getSqrt(1002002)=>1001 
getSqrt(100020000)=>10000 
getSqrt(100020001)=>10001 
getSqrt(100020002)=>10001 
getSqrt(10000200000)=>100000 
getSqrt(10000200001)=>100001 
getSqrt(10000200002)=>100001 
getSqrt(1000002000000)=>1000000 
getSqrt(1000002000001)=>1000001 
getSqrt(1000002000002)=>1000001 
getSqrt(100000020000000)=>10000000 
getSqrt(100000020000001)=>10000001 
getSqrt(100000020000002)=>10000001 
getSqrt(10000000200000000)=>100000000 
getSqrt(10000000200000001)=>100000001 
getSqrt(10000000200000002)=>100000001 

教訓

普段の業務で平方根が必要になることはありませんし、ましてやこの誤差が問題になることはないでしょう。ただ、この話はちょっとした教訓があるかと思います。

一つは挙動を理解することの大切さです。今回だとプリミティブな整数型や浮動小数点数型の挙動を理解していれば間違いを防げたり罠にハマるのを防げるかと思います。

もう一つはデータのとりうる値です。今回だと大きいな値がネックでしたが、どんなデータが来ても正しく動くようにするのは難易度が上がりますし、それを踏まえてとり得る値をちゃんと制御したりバリデーションする、連携するデータはシンプルで整合性が取れた状態を保つというのは大事なことなのではと思います。

無理矢理業務の話に繋げましたが、どんなものでも正確に挙動を理解しデータを管理した上で良いシステムを作っていければと思います!

明日のAdvent Calendarの記事もお楽しみに!

11
0
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
11
0