6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

numpy&numbaで世界最小ヒントの数独を解いた

Last updated at Posted at 2020-03-15

はじめに

Pythonで数独を解くってやり尽くされている感があるけど、このままだと自分の記憶にしか存在しないコードになってしまうので記事作成することにしました。

方針

深さ優先探索。
空欄のマスに数独の条件を満たすような数字を小さい順に仮に入れてみる。仮に入れた状態で他のマス目でも同じことをしてみる。どこかで矛盾したら直前に仮定したところを修正する。これをひたすら繰り返す。いつか矛盾が必ず起きなくなり、空欄が全て仮定の値で埋まる。それが正解。

image.png
図の引用元

実装

仮定が数独の条件を満たすか検証

配列のインデックスをタプル型で指定するよりも整数型で指定する方が実装しやすかったので、数独の配列は9×9行列を1×81ベクトルにしたもの(flatten)を想定している。indexの数え方は0〜81を左上から順に右に数えていく。上図を例にすると、左上の3×3ブロックのピンク色の数字「5, 3, 6, 9, 8」のindexはそれぞれ「0, 1, 9, 19, 20」となる。

def fillable(table, index, value):
    """計算中の行列において、正解だと仮定した値が数独の条件を満たしているか検証する。

    Parameters
    -------
        table : np.ndarray of int
            計算中の行列。2D-array。
        index : int
            仮定したマス目の場所。1行目は0~8, 2行目は9~17, ...
        value : int
            仮定した数。

    Returns
    -------
        bool
            計算中の行列において、仮定が正解である可能性があるかないか。
    """

    row = index // 9
    col = index % 9

    fillable_in_row = value not in table[row, :]
    fillable_in_col = value not in table[:, col]
    fillable_in_block = value not in table[
                                     (row // 3) * 3: (row // 3 + 1) * 3,
                                     (col // 3) * 3: (col // 3 + 1) * 3,
                                     ]

    return fillable_in_row and fillable_in_col and fillable_in_block

Pythonの基本構文しか使っていないので結構自然な実装だと思う。

仮定しまくり

def fill(table_flat):
    """深さ優先探索により、数独を解く。

    Parameters
    -------
        table_flat : np.ndarray of int
            数独の問題。1D-array

    Returns
    -------
        np.ndarray of int
            仮定が代入された行列。shape = (9, 9)
    """

    for tmp_i, tmp_val in enumerate(table_flat):

        if tmp_val == 0:
            fillable_vals = [val
                             for val in range(tmp_val + 1, 10)
                             if fillable(table_flat.reshape(9, 9), tmp_i, val)]

            for new_val in fillable_vals:
                table_flat[tmp_i] = new_val

                if fill(table_flat).all():
                    break
            else:
                table_flat[tmp_i] = 0
                break
    return table_flat.reshape(9, 9)

1行ずつ解説していく。

indexが小さいマス目から順に検討をする。

for tmp_i, tmp_val in enumerate(table_flat):

問題の空欄マス == 0として入力することを想定している。なので、以下の処理はマスが空欄の時だけ検討するということを意味する。

if tmp_val == 0:

ちょっと長めのリスト内包表記。
仮置きする値は現在の値よりも大きい数にする。しかし、数独の条件を満たさないような仮定は除去する。その検証は先ほどのfillable関数を使う。
問題の入力は9×9行列を1×81ベクトルにしたもの(flatten)を想定しているが、ルール上3×3のブロックでの抽出の必要があるのでこの時だけreshapeする。

fillable_vals = [val
                 for val in range(tmp_val + 1, 10)
                 if fillable(table_flat.reshape(9, 9), tmp_i, val)]

正解の可能性がある数をマスに代入してみる。

for new_val in fillable_vals:
    table_flat[tmp_i] = new_val

ここが肝。
まず、for/else文について説明する。for文が正常終了した場合にelse文に入る。正常終了ではない場合とは、for文がまだ終わっていないのにbreakするようなとき。ちなみに、for文が空回りした時も正常終了なのでelse文に入る。なので、ここではbreakの位置に注目すればいい。(ちなみに、Effective Python によるとfor/else分の使用は非推奨)

仮定した行列をfill関数の引数として再帰する。numpy.ndarray.all()は0でない数字が1つもなければTrueを返す。今回、問題の空欄マス == 0としてるので、all()の挙動を数独で言い換えると、空欄がなくなる、つまり、数独が解けたらfor文を抜ける。
解けていなければ、仮定したマスの値を0に戻して検討を終了する。0に戻した状態でreturnをするとそれがif fill(table_flat).all():に伝わり仮定を修正する処理が行われる。

    for ~~~  ~~~:
     ~~~  ~~~
        if fill(table_flat).all():
            break
    else:
        table_flat[tmp_i] = 0
        break
return table_flat.reshape(9, 9)

実行方法

fill関数とfillable関数を同じファイルに並べて、fill関数の引数に数独を1行にした81次ベクトルを引数に入れれば実行できる。

高速化(世界最小ヒント問題)

現状でも、普通の問題なら数ミリ秒〜数秒で解けるが、世界最小ヒントの数独を解こうとすると、いつまで経っても解が求められなかった。
image.png

現状のコードにこれを解かせようとするとパソコンがうなるだけなので、numbaによる高速化を試みる。

変更点① import

ファイルの先頭にnumbaをimport

from numba import njit, int64, bool_

変更点② デコレータをつけた

@njit(output, input)のようにして入力と出力の型を定義する。
今回1番詰まったのがここだった。numpy.ndarray.reshape()を使いたかったらint64[::1]とする必要があるらしい。(numbaの公式ドキュメントに書いてあった。ここらへん。)
本当はint64ではなくint8で良さそうだけど、上手くいかなかったのでデフォルトのままにした。(dtype=int8とか指定するのだろうか。)

# intの2次元配列とint2つを入力してboolを返す関数
@njit(bool_(int64[:, :], int64, int64))
def fillable(table, index, value):

# intの1次元配列を入力してintの2次元配列を返す関数
@njit(int64[:, :](int64[::1]))
def fill(table_flat):

変更点③ in演算を除去

numbaはPythonの構文はかなりサポートしているはずだが、fillable関数内のin演算でエラーが起きた。(fillable_in_row = value not in table[row, :]など。)
そのためfillable関数をin演算を使わずに実装し直した。

def fillable(table, index, value):

    row = index // 9
    col = index % 9
    block = lambda x: ((x // 3) * 3, (x // 3 + 1) * 3)

    same_in_areas = False
    for area in [
        table[row, :], table[:, col],
        table[block(row)[0]:block(row)[1], block(col)[0]:block(col)[1]].flatten()
    ]:
        for i in area:
            same_in_areas |= (i == value)

    return not same_in_areas

仮定したマスの縦横と3×3ブロックにの数字をiに入れて回す。そのiが仮定の値valueと一致する、つまり、数独の条件を満たさない場合があるか検証する。

コード全体

suudoku.py
def fillable(table, index, value):
    """計算中の行列において、正解だと仮定した値が数独の条件を満たしているか検証する。

    Parameters
    -------
        table : np.ndarray of int
            計算中の行列。2D-array。
        index : int
            仮定したマス目の場所。1行目は0~8, 2行目は9~17, ...
        value : int
            仮定した数。

    Returns
    -------
        bool
            計算中の行列において、仮定が正解である可能性があるかないか。
    """

    row = index // 9
    col = index % 9

    fillable_in_row = value not in table[row, :]
    fillable_in_col = value not in table[:, col]
    fillable_in_block = value not in table[
                                     (row // 3) * 3: (row // 3 + 1) * 3,
                                     (col // 3) * 3: (col // 3 + 1) * 3,
                                     ]

    return fillable_in_row and fillable_in_col and fillable_in_block


def fill(table_flat):
    """深さ優先探索により、数独を解く。

    Parameters
    -------
        table_flat : np.ndarray of int
            数独の問題。1D-array

    Returns
    -------
        np.ndarray of int
            仮定が代入された行列。shape = (9, 9)
    """

    for tmp_i, tmp_val in enumerate(table_flat):

        if tmp_val == 0:
            fillable_vals = [val
                             for val in range(tmp_val + 1, 10)
                             if fillable(table_flat.reshape(9, 9), tmp_i, val)]

            for new_val in fillable_vals:
                table_flat[tmp_i] = new_val

                if fill(table_flat).all():
                    break
            else:
                table_flat[tmp_i] = 0
                break
    return table_flat.reshape(9, 9)
suudoku_faster.py
from numba import njit, int64, bool_


@njit(bool_(int64[:, :], int64, int64))
def fillable(table, index, value):

    row = index // 9
    col = index % 9
    block = lambda x: ((x // 3) * 3, (x // 3 + 1) * 3)

    same_in_areas = False
    for area in [
        table[row, :], table[:, col],
        table[block(row)[0]:block(row)[1], block(col)[0]:block(col)[1]].flatten()
    ]:
        for i in area:
            same_in_areas |= (i == value)

    return not same_in_areas


@njit(int64[:, :](int64[::1]))
def fill(table_flat):

    for tmp_i, tmp_val in enumerate(table_flat):

        if tmp_val == 0:
            fillable_vals = [val
                             for val in range(tmp_val + 1, 10)
                             if fillable(table_flat.reshape(9, 9), tmp_i, val)]

            for new_val in fillable_vals:
                table_flat[tmp_i] = new_val

                if fill(table_flat).all():
                    break
            else:
                table_flat[tmp_i] = 0
                break
    return table_flat.reshape(9, 9)
suudoku_input.py
from suudoku import fill as fill_slow
from suudoku_faster import fill as fill_fast
import matplotlib.pyplot as plt
import numpy as np
from time import time

table =  [0, 0, 0, 8, 0, 1, 0, 0, 0,
          0, 0, 0, 0, 0, 0, 4, 3, 0,
          5, 0, 0, 0, 0, 0, 0, 0, 0,
          0, 0, 0, 0, 7, 0, 8, 0, 0,
          0, 0, 0, 0, 0, 0, 1, 0, 0,
          0, 2, 0, 0, 3, 0, 0, 0, 0,
          6, 0, 0, 0, 0, 0, 0, 7, 5,
          0, 0, 3, 4, 0, 0, 0, 0, 0,
          0, 0, 0, 2, 0, 0, 6, 0, 0]

# numbaの1周目は計算時間だけでなくコンパイル時間を含むので1回空回りさせる。
fill_fast(np.array(table).copy())

start = time()
ans = fill_fast(np.array(table).copy())
finish = time()

print("解答")
print(ans)
print()
print(f"計算時間(sec): {finish - start:5.7}")

fig, ax = plt.subplots(figsize=(6, 6))
fig.patch.set_visible(False)
ax.axis("off")

colors = np.where(table, "#F5A9BC", "#ffffff").reshape(9, 9)
picture = ax.table(cellText=ans, cellColours=colors, loc='center')
picture.set_fontsize(25)
picture.scale(1, 3)

ax.axhline(y=0.340, color="b")
ax.axhline(y=0.665, color="b")
ax.axvline(x=0.335, color="b")
ax.axvline(x=0.665, color="b")

fig.tight_layout()
plt.show()

出力

普通の問題なら0.01秒くらいなのにnumbaを持ってしても、最小ヒント問題は100秒かかってしまうのね…
ちなみにこっちの世界一難しい数独は普通に解けた。

解答
[[2 3 4 8 9 1 5 6 7]
 [1 6 9 7 2 5 4 3 8]
 [5 7 8 3 4 6 9 1 2]
 [3 1 6 5 7 4 8 2 9]
 [4 9 7 6 8 2 1 5 3]
 [8 2 5 1 3 9 7 4 6]
 [6 4 2 9 1 8 3 7 5]
 [9 5 3 4 6 7 2 8 1]
 [7 8 1 2 5 3 6 9 4]]

計算時間(sec): 101.6835

image.png

6
3
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?