66
40

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

BrainPadAdvent Calendar 2017

Day 22

数独問題を解くアルゴリズムと実装

Last updated at Posted at 2017-12-21

この記事はBrainPad Advent Calender 2017の22日目の記事です。
こんにちは、BrainpadでWebエンジニアやっています、チンバトと申します。本記事ではいくつかのアルゴリズムで数独問題を解いて見たのをまとめました。

数独問題の解き方

いろいろあると思います。普通に眼、頭、手の組み合わせで解いたり、オンラインツール使ったり、だれか解を知っている人に教えてもらったりとか。。。しかし、今回はプログラミングで解く方法について簡単に紹介したいと思います。

1. Backtracking

要はよくある深さ優先探索のことです。つまり、ありえる組み合わせを繰り返してチェックして行きます。もちろん途中で枝刈りしたりなどの工夫入れることで速度を上げることも可能です。単純に$9\times9$のフィールドに1から9までの数字を入れる組み合わせの数は$9^{81}$個(約 $2\times10^{77}$)で、その中で数独の制約を満たすのでも$約6.67\times10^{21}$(厳密には6670903752021072936960個)でまだまだ膨大な数字です。そのため、組み合わせをチェックしていくこの方法は他の方法に比べて遅かったりしますが、問題に解があれば必ず見つけられるというメリットもあります。

イメージはこれです:
alt

2. Stochastic search

数独問題を確率的に解く方法で、収束が速いですが、解が必ず見つかる保証はありません。例えば以下のように解くこともあります:

  • 全ての空いているマスに適当に数字を入れておく
  • エラーを計算する(制約に引っかかった数など)
  • 数字をシャッフルさせてエラーが減るように繰り返す

シャッフルのアルゴリズムとして遺伝的アルゴリズムなどが使われたりします。

最近お流行りのディープラーニングで数独問題を解くのも分類としてはこちらの「Stochastic search」になるかと思います。結局、Gradient Descentでエラーが小さくなるようにパラメータを探していますから。

3. Constraint programming

数独問題を制約充足問題として定義すれば、最適化の汎用ソルバーで解けますよねという方法です。この方法でも解があれば必ず見つかります。

4. Exact cover

これはConstraint programmingだと言えば、Constraint programmingなのですが、アルゴリズムが結構面白いので項目として分けています。すごく簡単に言いますと、あるセットを複数のサブセットでMECE(重複なく、もれなく)にカバーできるようなサブセットの集合を探すNP完全問題です。Donald Knuthさんが発案したAlgorithm Xなどを使えばExact cover問題の全ての解が見つかります。なので、数独問題をExact cover問題として定義することさえできればいいわけです。

実装

解き方ごとにPythonでの実装例を示します。実装ですが、もちろん全て自力で書いた訳ではなく、他者の書いたコードも参考にさせていただいたりしました。

実装方法に問わず、共通のサンプル問題として以下を使います:

s.png

ちなみに、数独好きな嫁に聞くとこの問題はそんなに難しくないらしいですよ。11分で解いてくれました。

1. ショートコーディングのBacktracking

sudoku.py
def r(a):n=9;b=3;i=a.find('0');~i or exit(a);[m in[(i-j)%n*(i/n^j/n)*(i/(n*b)^j/(n*b)|i%n/b^j%n/b)or a[j]for j in range(n**2)]or r(a[:i]+m+a[i+1:])for m in map(str,range(1,n+1))]
from sys import*;r(argv[1])

はい、終わりです。実行してみます

$> time python short_solver.py 000070605000004080200006070073000200500000004004000960080400003050200000706080000
438172695617954382295836471173649258569728134824513967982465713351297846746381529
python short_solver.py   9.66s user 0.03s system 98% cpu 9.807 total

問題を左上から右下に向けて数字として書き、数字がない所は0で埋めて表現しています。ちなみにコードはわざとショートにしていますが、展開してみるとワーストケースが$O(n^4)$の馬鹿なBacktrackingアルゴリズムであることがわかるかと思います。

2. Stochastic search: 機械学習で数独問題に挑む

私はWebエンジニアなので機械学習など詳しくなくて、同僚の機械学習エンジニアにヘルプ求めました。周りに聞ける人がいるのはとてもいいですね。
今回は以下の流れで実装しています:

  • データ準備: kaggleに100万件の数独問題と解答のデータがありましたので、そこから1万件ぐらい抜粋しました。今回は精度をそこまで重視しない雰囲気だけの実装なのでデータは少なめでやっています。
  • モデル作成: すごくシンプルなCNNにしました。1層しかなく、それがshallow learningと呼ばれるみたいです。
  • 実行: 学習済みモデルに例の問題を食わせて結果を取得します。Stochastic searchの特徴としては結果が必ず解である保証がないことです。さて解が見つかるかな?
create_model.ipynb
import csv
import numpy as np

from keras.models import Model, load_model
from keras.layers import Conv2D, Input, Concatenate, Flatten, Reshape, Dense, Activation
from keras.utils import plot_model
from keras.preprocessing.image import Iterator


def build_model():
    inp = Input((9, 9, 10))
    x1 = Conv2D(filters=10, kernel_size=(3, 3), strides=(3, 3), padding='valid', activation='relu')(inp)
    x2 = Conv2D(filters=10, kernel_size=(9, 1), padding='valid', activation='relu')(inp)
    x3 = Conv2D(filters=10, kernel_size=(1, 9), padding='valid', activation='relu')(inp)

    x1 = Flatten()(x1)
    x2 = Flatten()(x2)
    x3 = Flatten()(x3)

    x = Concatenate()([x1, x2, x3])
    x = Dense(9 * 9 * 9)(x)
    x = Reshape((9, 9, 9))(x)
    x = Activation('softmax')(x)

    model = Model(inp, x)
    model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
    model.summary()

    return model


def read_csv(path):
    with open(path) as f:
        reader = csv.reader(f)
        next(reader)  # first row is header

        quizes = []
        solutions = []
        for quiz, solution in reader:
            quizes.append(quiz)
            solutions.append(solution)
    return quizes, solutions


def construct_board(s, is_solution=False):
    if is_solution:
        board = np.zeros((9, 9, 9), dtype=np.float32)
    else:
        board = np.zeros((9, 9, 10), dtype=np.float32)  # the 10th dim is for empty grid
    for i, ss in enumerate(s):
        board[i // 9, i % 9, int(ss) - 1] = 1.0   # map 1 to 0th, 2 to 1th, ..., 9 to 8th, and 0 to -1th
    return board


class SudokuIterator(Iterator):
    def __init__(self, csv_path, batch_size, shuffle=False, seed=0):
        quizes, solutions = read_csv(csv_path)
        self.x = np.array([construct_board(q) for q in quizes])
        self.y = np.array([construct_board(s, is_solution=True) for s in solutions])

        n = len(self.x)
        self.steps_per_epoch = n // batch_size
        super(SudokuIterator, self).__init__(n, batch_size, shuffle, seed)

    def _get_batches_of_samples(self, index_array):
        batch_x = self.x[index_array]
        batch_y = self.y[index_array]
        return batch_x, batch_y

    def _get_batches_of_transformed_samples(self, index_array):
        return self._get_batches_of_samples(index_array)  # Keras Iterator requirement

    def next(self):
        with self.lock:
            index_array = next(self.index_generator)
        return self._get_batches_of_samples(index_array)


print('building model...')
sudoku_model = build_model()
print('done.')

print('building generator...')
train_gen = SudokuIterator(csv_path='sudoku10000.csv', batch_size=64, shuffle=True)
print('done.')

sudoku_model.fit_generator(train_gen, steps_per_epoch=train_gen.steps_per_epoch, epochs=50)
sudoku_model.save('sudoku_model.h5')

こちらjupyter notebook上で実行していましたが、以下が学習結果です:

building model...
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            (None, 9, 9, 10)     0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 3, 3, 10)     910         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_2 (Conv2D)               (None, 1, 9, 10)     910         input_1[0][0]                    
__________________________________________________________________________________________________
conv2d_3 (Conv2D)               (None, 9, 1, 10)     910         input_1[0][0]                    
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 90)           0           conv2d_1[0][0]                   
__________________________________________________________________________________________________
flatten_2 (Flatten)             (None, 90)           0           conv2d_2[0][0]                   
__________________________________________________________________________________________________
flatten_3 (Flatten)             (None, 90)           0           conv2d_3[0][0]                   
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 270)          0           flatten_1[0][0]                  
                                                                 flatten_2[0][0]                  
                                                                 flatten_3[0][0]                  
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 729)          197559      concatenate_1[0][0]              
__________________________________________________________________________________________________
reshape_1 (Reshape)             (None, 9, 9, 9)      0           dense_1[0][0]                    
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 9, 9, 9)      0           reshape_1[0][0]                  
==================================================================================================
Total params: 200,289
Trainable params: 200,289
Non-trainable params: 0
__________________________________________________________________________________________________
done.
building generator...
done.
Epoch 1/50
156/156 [==============================] - 6s 39ms/step - loss: 2.1754 - acc: 0.1548
Epoch 2/50
156/156 [==============================] - 5s 33ms/step - loss: 1.9703 - acc: 0.3388
Epoch 3/50
156/156 [==============================] - 5s 34ms/step - loss: 1.6940 - acc: 0.4593
Epoch 4/50
156/156 [==============================] - 6s 37ms/step - loss: 1.5140 - acc: 0.5098
...
Epoch 50/50
156/156 [==============================] - 6s 38ms/step - loss: 0.9684 - acc: 0.6565

学習の結果「acc: 0.6565」(精度)のモデルができたようです。つまり、学習したデータであれば65%位は正解しますよっと。もちろん今回はすごくシンプルにやっているので過学習したりしている可能性が大いにあります。一旦このモデルを利用して問題を解いてくれるスクリプト書きました:

ml_solver.py
import sys
import numpy as np
from keras.models import load_model


def construct_board(s, is_solution=False):
    if is_solution:
        board = np.zeros((9, 9, 9), dtype=np.float32)
    else:
        board = np.zeros((9, 9, 10), dtype=np.float32)  # the 10th dim is for empty grid
    for i, ss in enumerate(s):
        board[i // 9, i % 9, int(ss) - 1] = 1.0   # map 1 to 0th, 2 to 1th, ..., 9 to 8th, and 0 to -1th
    return board

sudoku_model = load_model('sudoku_model.h5')
answer = sudoku_model.predict(np.array([construct_board(sys.argv[1])]))

for r in range(9):
    row = ''
    for c in range(9):
        row += str(answer[0][r,c].argmax() + 1)
    print(row)

実際に実行してみます:

$>time python ml_solver.py 000070605000004080200006070073000200500000004004000960080400003050200000706080000
Using TensorFlow backend.
139878645
667344182
468356779
673868255
565827814
514827968
181485753
352266796
736283542
python ml_solver.py   3.24s user 0.69s system 64% cpu 6.044 total

結果をみればわかる通り、明らかに数独の制約が満たされてないですね。。。つまり、問題を解けませんでした。まあ、今回は学習データが少ないですし、シンプルにやっていますので雰囲気だけ伝わればいいかなと思います。ちなみに、実行時間は3.24sとなっていますが、ほとんどがモデルのロードに時間かかっています。ロード済みのモデルに解を聞いて表示するだけなら、0.004sでした。

3. Pulp使ったConstraint programming

Pulpというパッケージ使って問題を最適化問題として定義し、デフォルトのCBCソルバーで解いてみます。

pulp_solver.py
# -*- coding: utf-8 -*-
import argparse
from pulp import LpVariable, LpInteger, LpProblem, LpMinimize, LpStatus, lpSum, value


def main(inp):
    n = 9
    b = 3
    digits = [str(d + 1) for d in range(n)]
    values = rows = columns = digits
    answers = []

    choices = LpVariable.dicts("Choice", (values, rows, columns), 0, 1, LpInteger)
    boxes = [[(rows[b * i + k], columns[b * j + l]) for k in range(b) for l in range(b)] for j in range(b) for i in range(b)]

    # 問題提議
    problem = LpProblem("Solving Sudoku", LpMinimize)  # MinimizeでもMaximizeでもOK
    problem += 0, "Arbitrary Objective Function"

    # 制約追加
    for r in rows:
        for c in columns:
            problem += lpSum([choices[v][r][c] for v in values]) == 1, ""

    for v in values:
        for r in rows:
            problem += lpSum([choices[v][r][c] for c in columns]) == 1, ""

        for c in columns:
            problem += lpSum([choices[v][r][c] for r in rows]) == 1, ""

        for b in boxes:
            problem += lpSum([choices[v][r][c] for (r, c) in b]) == 1, ""

    for i in range(n**2):
        val = inp[i]
        if val != '0':
            problem += choices[str(val)][str(i/n + 1)][str(i % n + 1)] == 1, ""

    while True:
        # cbcソルバー利用
        problem.solve()
        if LpStatus[problem.status] == "Optimal":
            answers.append(''.join([v for r in rows for c in columns for v in values if value(choices[v][r][c]) == 1]))
            # 見つけた解を制約として追加
            problem += lpSum(
                [choices[v][r][c] for v in values for r in rows for c in columns if value(choices[v][r][c]) == 1]
            ) <= 80
        else:
            break

    if answers:
        # 最初の解だけ表示
        print(answers[0])


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('inp', type=str)
    args = parser.parse_args()
    main(args.inp)

問題の定義が少し面倒ですよね。実行してみます:

$>time python pulp_solver.py 000070605000004080200006070073000200500000004004000960080400003050200000706080000
438172695617954382295836471173649258569728134824513967982465713351297846746381529
python solver.py   0.23s user 0.08s system 57% cpu 0.538 total

そこそこ速いですね。

4. Exact cover問題を解くDancing links実装

Dancing linksはExact cover問題を解くAlgorithm Xを発案したKnuthさんが提案しているAlgorithm Xの実装です。この実装では、ボードを双方向循環リストで表現し、数字がないところの計算をスキップするため、速いです。また以下のPython実装では、双方向循環リストの変わりに辞書を使っていますが、アクセスされるパターンとしては同じですし、コードがシンプルで短くなっているだけです。

dancing_links_sudoku.py
import sys
from itertools import product


def solve_sudoku(grid):
    b = 3
    n = b * b
    x = ([("rc", rc) for rc in product(range(n), range(n))] +
         [("rn", rn) for rn in product(range(n), range(1, n + 1))] +
         [("cn", cn) for cn in product(range(n), range(1, n + 1))] +
         [("bn", bn) for bn in product(range(n), range(1, n + 1))])
    y = dict()
    for r, c, n in product(range(n), range(n), range(1, n + 1)):
        box_number = (r // b) * b + (c // b)
        y[(r, c, n)] = [
            ("rc", (r, c)),
            ("rn", (r, n)),
            ("cn", (c, n)),
            ("bn", (box_number, n))]
    x, y = exact_cover(x, y)
    for i, row in enumerate(grid):
        for j, n in enumerate(row):
            if n:
                select(x, y, (i, j, n))

    for solution in solve(x, y, []):
        for (r, c, n) in solution:
            grid[r][c] = n
        yield grid


def exact_cover(x, y):
    x = {j: set() for j in x}
    for i, row in y.items():
        for j in row:
            x[j].add(i)
    return x, y


def solve(x, y, solution):
    if not x:
        yield list(solution)
    else:
        c = min(x, key=lambda c: len(x[c]))
        for r in list(x[c]):
            solution.append(r)
            cols = select(x, y, r)
            for s in solve(x, y, solution):
                yield s
            deselect(x, y, r, cols)
            solution.pop()


def select(x, y, r):
    cols = []
    for j in y[r]:
        for i in x[j]:
            for k in y[i]:
                if k != j:
                    x[k].remove(i)
        cols.append(x.pop(j))
    return cols


def deselect(x, y, r, cols):
    for j in reversed(y[r]):
        x[j] = cols.pop()
        for i in x[j]:
            for k in y[i]:
                if k != j:
                    x[k].add(i)


if __name__ == "__main__":
    n = 9
    q = sys.argv[1]
    p = [map(int, list(q[i*n:(i+1)*n])) for i in range(n)]
    for solution in solve_sudoku(p):
        print(''.join([''.join(map(str, r)) for r in solution]))

こんな難しいアルゴリズムの実装を80行ぐらいで書けちゃうのはPythonの良さですよね。実行してみます。

$>time python dancing_links_sudoku.py 000070605000004080200006070073000200500000004004000960080400003050200000706080000
438172695617954382295836471173649258569728134824513967982465713351297846746381529
python dancing_links_sudoku.py   0.04s user 0.02s system 66% cpu 0.084 total

Pulp使ったのよりも速いですね。

まとめ

本記事では数独問題の解き方について紹介し、それぞれのPythonでの実装例を示しました。Stochastic以外の方法では確実に解が見つかりますが、Stochastic方法では解が見つかる保証はありませんが、もし解である場合の実行時間が短いというメリットがありますね。そのため、Stochasticな感じで改めて学習済みのモデルに入れてみて解が制約満たさなかったらExact cover問題として解くなど複数のアルゴリズムを組み合わせることで多くの問題を解く際の合計時間を減らせるかもしれません。こう言う方法は数独問題に限らず様々な問題に対しても活用できそうですね。

数独はもともと人間が頭の体操を目的にしたゲームで、それを機械に解かせて意味あるのと思っちゃう人いるかもしれませんが、機械に解かせるために私も頭の体操をしましたのでご安心ください。

最後になりますが、私は本記事で趣味として色々とアルゴリズムのことを書きましたが、普段は自社プロダクトの開発&運用している普通のWebエンジニアです。ブレインパッドでは一緒に働くWebエンジニアの仲間達を募集中ですので宣伝しておきます。


https://jobs.forkwell.com/BrainPad/jobs/2458

参考

66
40
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
66
40

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?