はじめに
Pythonで数独を解くってやり尽くされている感があるけど、このままだと自分の記憶にしか存在しないコードになってしまうので記事作成することにしました。
方針
深さ優先探索。
空欄のマスに数独の条件を満たすような数字を小さい順に仮に入れてみる。仮に入れた状態で他のマス目でも同じことをしてみる。どこかで矛盾したら直前に仮定したところを修正する。これをひたすら繰り返す。いつか矛盾が必ず起きなくなり、空欄が全て仮定の値で埋まる。それが正解。
実装
仮定が数独の条件を満たすか検証
配列のインデックスをタプル型で指定するよりも整数型で指定する方が実装しやすかったので、数独の配列は9×9行列を1×81ベクトルにしたもの(flatten)を想定している。index
の数え方は0〜81を左上から順に右に数えていく。上図を例にすると、左上の3×3ブロックのピンク色の数字「5, 3, 6, 9, 8」のindex
はそれぞれ「0, 1, 9, 19, 20」となる。
def fillable(table, index, value):
"""計算中の行列において、正解だと仮定した値が数独の条件を満たしているか検証する。
Parameters
-------
table : np.ndarray of int
計算中の行列。2D-array。
index : int
仮定したマス目の場所。1行目は0~8, 2行目は9~17, ...
value : int
仮定した数。
Returns
-------
bool
計算中の行列において、仮定が正解である可能性があるかないか。
"""
row = index // 9
col = index % 9
fillable_in_row = value not in table[row, :]
fillable_in_col = value not in table[:, col]
fillable_in_block = value not in table[
(row // 3) * 3: (row // 3 + 1) * 3,
(col // 3) * 3: (col // 3 + 1) * 3,
]
return fillable_in_row and fillable_in_col and fillable_in_block
Pythonの基本構文しか使っていないので結構自然な実装だと思う。
仮定しまくり
def fill(table_flat):
"""深さ優先探索により、数独を解く。
Parameters
-------
table_flat : np.ndarray of int
数独の問題。1D-array
Returns
-------
np.ndarray of int
仮定が代入された行列。shape = (9, 9)
"""
for tmp_i, tmp_val in enumerate(table_flat):
if tmp_val == 0:
fillable_vals = [val
for val in range(tmp_val + 1, 10)
if fillable(table_flat.reshape(9, 9), tmp_i, val)]
for new_val in fillable_vals:
table_flat[tmp_i] = new_val
if fill(table_flat).all():
break
else:
table_flat[tmp_i] = 0
break
return table_flat.reshape(9, 9)
1行ずつ解説していく。
indexが小さいマス目から順に検討をする。
for tmp_i, tmp_val in enumerate(table_flat):
問題の空欄マス == 0
として入力することを想定している。なので、以下の処理はマスが空欄の時だけ検討するということを意味する。
if tmp_val == 0:
ちょっと長めのリスト内包表記。
仮置きする値は現在の値よりも大きい数にする。しかし、数独の条件を満たさないような仮定は除去する。その検証は先ほどのfillable
関数を使う。
問題の入力は9×9行列を1×81ベクトルにしたもの(flatten)を想定しているが、ルール上3×3のブロックでの抽出の必要があるのでこの時だけreshapeする。
fillable_vals = [val
for val in range(tmp_val + 1, 10)
if fillable(table_flat.reshape(9, 9), tmp_i, val)]
正解の可能性がある数をマスに代入してみる。
for new_val in fillable_vals:
table_flat[tmp_i] = new_val
ここが肝。
まず、for/else
文について説明する。for
文が正常終了した場合にelse
文に入る。正常終了ではない場合とは、for
文がまだ終わっていないのにbreak
するようなとき。ちなみに、for
文が空回りした時も正常終了なのでelse
文に入る。なので、ここではbreak
の位置に注目すればいい。(ちなみに、Effective Python によるとfor/else
分の使用は非推奨)
仮定した行列をfill
関数の引数として再帰する。numpy.ndarray.all()
は0でない数字が1つもなければTrue
を返す。今回、問題の空欄マス == 0
としてるので、all()
の挙動を数独で言い換えると、空欄がなくなる、つまり、数独が解けたらfor
文を抜ける。
解けていなければ、仮定したマスの値を0に戻して検討を終了する。0に戻した状態でreturn
をするとそれがif fill(table_flat).all():
に伝わり仮定を修正する処理が行われる。
for ~~~ 略 ~~~:
~~~ 略 ~~~
if fill(table_flat).all():
break
else:
table_flat[tmp_i] = 0
break
return table_flat.reshape(9, 9)
実行方法
fill
関数とfillable
関数を同じファイルに並べて、fill
関数の引数に数独を1行にした81次ベクトルを引数に入れれば実行できる。
高速化(世界最小ヒント問題)
現状でも、普通の問題なら数ミリ秒〜数秒で解けるが、世界最小ヒントの数独を解こうとすると、いつまで経っても解が求められなかった。
現状のコードにこれを解かせようとするとパソコンがうなるだけなので、numbaによる高速化を試みる。
変更点① import
ファイルの先頭にnumbaをimport
from numba import njit, int64, bool_
変更点② デコレータをつけた
@njit(output, input)
のようにして入力と出力の型を定義する。
今回1番詰まったのがここだった。numpy.ndarray.reshape()
を使いたかったらint64[::1]
とする必要があるらしい。(numbaの公式ドキュメントに書いてあった。ここらへん。)
本当はint64
ではなくint8
で良さそうだけど、上手くいかなかったのでデフォルトのままにした。(dtype=int8
とか指定するのだろうか。)
# intの2次元配列とint2つを入力してboolを返す関数
@njit(bool_(int64[:, :], int64, int64))
def fillable(table, index, value):
# intの1次元配列を入力してintの2次元配列を返す関数
@njit(int64[:, :](int64[::1]))
def fill(table_flat):
変更点③ in
演算を除去
numbaはPythonの構文はかなりサポートしているはずだが、fillable
関数内のin
演算でエラーが起きた。(fillable_in_row = value not in table[row, :]
など。)
そのためfillable
関数をin
演算を使わずに実装し直した。
def fillable(table, index, value):
row = index // 9
col = index % 9
block = lambda x: ((x // 3) * 3, (x // 3 + 1) * 3)
same_in_areas = False
for area in [
table[row, :], table[:, col],
table[block(row)[0]:block(row)[1], block(col)[0]:block(col)[1]].flatten()
]:
for i in area:
same_in_areas |= (i == value)
return not same_in_areas
仮定したマスの縦横と3×3ブロックにの数字をi
に入れて回す。そのi
が仮定の値value
と一致する、つまり、数独の条件を満たさない場合があるか検証する。
コード全体
def fillable(table, index, value):
"""計算中の行列において、正解だと仮定した値が数独の条件を満たしているか検証する。
Parameters
-------
table : np.ndarray of int
計算中の行列。2D-array。
index : int
仮定したマス目の場所。1行目は0~8, 2行目は9~17, ...
value : int
仮定した数。
Returns
-------
bool
計算中の行列において、仮定が正解である可能性があるかないか。
"""
row = index // 9
col = index % 9
fillable_in_row = value not in table[row, :]
fillable_in_col = value not in table[:, col]
fillable_in_block = value not in table[
(row // 3) * 3: (row // 3 + 1) * 3,
(col // 3) * 3: (col // 3 + 1) * 3,
]
return fillable_in_row and fillable_in_col and fillable_in_block
def fill(table_flat):
"""深さ優先探索により、数独を解く。
Parameters
-------
table_flat : np.ndarray of int
数独の問題。1D-array
Returns
-------
np.ndarray of int
仮定が代入された行列。shape = (9, 9)
"""
for tmp_i, tmp_val in enumerate(table_flat):
if tmp_val == 0:
fillable_vals = [val
for val in range(tmp_val + 1, 10)
if fillable(table_flat.reshape(9, 9), tmp_i, val)]
for new_val in fillable_vals:
table_flat[tmp_i] = new_val
if fill(table_flat).all():
break
else:
table_flat[tmp_i] = 0
break
return table_flat.reshape(9, 9)
from numba import njit, int64, bool_
@njit(bool_(int64[:, :], int64, int64))
def fillable(table, index, value):
row = index // 9
col = index % 9
block = lambda x: ((x // 3) * 3, (x // 3 + 1) * 3)
same_in_areas = False
for area in [
table[row, :], table[:, col],
table[block(row)[0]:block(row)[1], block(col)[0]:block(col)[1]].flatten()
]:
for i in area:
same_in_areas |= (i == value)
return not same_in_areas
@njit(int64[:, :](int64[::1]))
def fill(table_flat):
for tmp_i, tmp_val in enumerate(table_flat):
if tmp_val == 0:
fillable_vals = [val
for val in range(tmp_val + 1, 10)
if fillable(table_flat.reshape(9, 9), tmp_i, val)]
for new_val in fillable_vals:
table_flat[tmp_i] = new_val
if fill(table_flat).all():
break
else:
table_flat[tmp_i] = 0
break
return table_flat.reshape(9, 9)
from suudoku import fill as fill_slow
from suudoku_faster import fill as fill_fast
import matplotlib.pyplot as plt
import numpy as np
from time import time
table = [0, 0, 0, 8, 0, 1, 0, 0, 0,
0, 0, 0, 0, 0, 0, 4, 3, 0,
5, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 7, 0, 8, 0, 0,
0, 0, 0, 0, 0, 0, 1, 0, 0,
0, 2, 0, 0, 3, 0, 0, 0, 0,
6, 0, 0, 0, 0, 0, 0, 7, 5,
0, 0, 3, 4, 0, 0, 0, 0, 0,
0, 0, 0, 2, 0, 0, 6, 0, 0]
# numbaの1周目は計算時間だけでなくコンパイル時間を含むので1回空回りさせる。
fill_fast(np.array(table).copy())
start = time()
ans = fill_fast(np.array(table).copy())
finish = time()
print("解答")
print(ans)
print()
print(f"計算時間(sec): {finish - start:5.7}")
fig, ax = plt.subplots(figsize=(6, 6))
fig.patch.set_visible(False)
ax.axis("off")
colors = np.where(table, "#F5A9BC", "#ffffff").reshape(9, 9)
picture = ax.table(cellText=ans, cellColours=colors, loc='center')
picture.set_fontsize(25)
picture.scale(1, 3)
ax.axhline(y=0.340, color="b")
ax.axhline(y=0.665, color="b")
ax.axvline(x=0.335, color="b")
ax.axvline(x=0.665, color="b")
fig.tight_layout()
plt.show()
出力
普通の問題なら0.01秒くらいなのにnumbaを持ってしても、最小ヒント問題は100秒かかってしまうのね…
ちなみにこっちの世界一難しい数独は普通に解けた。
解答
[[2 3 4 8 9 1 5 6 7]
[1 6 9 7 2 5 4 3 8]
[5 7 8 3 4 6 9 1 2]
[3 1 6 5 7 4 8 2 9]
[4 9 7 6 8 2 1 5 3]
[8 2 5 1 3 9 7 4 6]
[6 4 2 9 1 8 3 7 5]
[9 5 3 4 6 7 2 8 1]
[7 8 1 2 5 3 6 9 4]]
計算時間(sec): 101.6835