LoginSignup
4
3

More than 3 years have passed since last update.

optunaに数独を解かせてみた

Posted at

はじめに

動機

何となくの思い付きです。

数独を人力以外で解きたかったら、線形計画法やHopfield networkを使えば解けるのはよく知られたことだと思います。
しかし、数独のルールを知らなくても試行回数でごり押しをすれば解けるのではないかと思ったので、実際に試してみました。

数独とは

パズルゲームの一種です。
詳細はwikipediaを参照してください。数独

oputunaとは

optuna
Preferred Networks社によって開発されたハイパーパラメータを最適化するためのpythonのライブラリです。
ハイパーパラメータの探索アルゴリズムはTPEを利用しており、これはHyperOptと同じです。

細かな説明は省略して一言でいうと、最適化の対象がブラックボックスでも、いい感じに目的関数を最適化してくれるということです。

方法

前述の通り、optunaを利用します。
そのために正解に近づくほど値が小さくなるように目的関数を設定します。

目的関数

目的関数はルール違反の個数を数えることにします。
ここでは違反が少なくなることと正解に近づくことは同義です。
明示的に正解は知りませんが、目的関数の値をどんどん最小化していくと、最終的には違反が0、正解に到達します。

def evaluate(answer):
    # 数独のループに違反してい個数を返す。正解なら評価は0になる
    # answerは9x9の2次元配列で、各要素に1~9の数字が入っている
    tmp = np.reshape(answer, [3, 3, 3, 3])
    loss = np.sum((
        # 各列で違反している個数
        np.sum([np.count_nonzero(np.logical_not(np.any(answer == i, axis=0))) for i in range(9)]),
        # 各行で違反している個数
        np.sum([np.count_nonzero(np.logical_not(np.any(answer == i, axis=1))) for i in range(9)]),
        # 3x3の領域ごとで違反している個数
        np.sum([np.count_nonzero(np.logical_not(np.any(tmp == i, axis=(1, 3)))) for i in range(9)]),
    ))
    return loss

違反の個数は、重複した数字の個数を数えるのではなく、逆に登場しなかった数字の個数を数えています。

ルールを満たしているならば、各条件において1~9が1回ずつ登場します。
ルール違反で重複がある場合は、別のどれかの数字の登場回数が減って0になっているので、こちらを数えます。

重複の場合は2~9回まで様々なパターンがありますが、登場回数が減る場合は0回しかないので、こちらを数えた方がシンプルになると思います。

optunaの呼び出しなど(コード全体)

解きたい問題はソースコードにハードコーディングしました。
基本的にはoptunaのドキュメントのコードのそのままです。

from itertools import product

import numpy as np
import optuna
import click


"""
問題はhttps://ja.wikipedia.org/wiki/%E6%95%B0%E7%8B%ACの例題の画像より
 -----------------        
|5|3| | |7| | | | |
|-+-+-+-+-+-+-+-+-|
|6| | |1|9|5| | | |
|-+-+-+-+-+-+-+-+-|
| |9|8| | | | |6| |
|-+-+-+-+-+-+-+-+-|
|8| | | |6| | | |3|
|-+-+-+-+-+-+-+-+-|
|4| | |8| |3| | |1|
|-+-+-+-+-+-+-+-+-|
|7| | | |2| | | |6|
|-+-+-+-+-+-+-+-+-|
| |6| | | | |2|8| |
|-+-+-+-+-+-+-+-+-|
| | | |4|1|9| | |5|
|-+-+-+-+-+-+-+-+-|
| | | | |8| | |7|9|
 -----------------        
"""
# 予め値が入っている部分
preset = {'p00': 5, 'p01': 3, 'p04': 7,
          'p10': 6, 'p13': 1, 'p14': 9, 'p15': 5,
          'p21': 9, 'p22': 8, 'p27': 6,
          'p30': 8, 'p34': 6, 'p38': 3,
          'p40': 4, 'p43': 8, 'p45': 3, 'p48': 1,
          'p50': 7, 'p54': 2, 'p58': 6,
          'p61': 6, 'p66': 2, 'p67': 8,
          'p73': 4, 'p74': 1, 'p75': 9, 'p78': 5,
          'p84': 8, 'p87': 7, 'p88': 9}


def evaluate(answer):
    # 上記の通り


def objective(trial):
    candidate = (1, 2, 3, 4, 5, 6, 7, 8, 9)

    answer = np.empty([9, 9], dtype=np.uint8)
    for i, j in product(range(9), repeat=2):
        key = 'p{}{}'.format(i, j)
        if key in preset:
            answer[i, j] = preset[key]
        else:
            answer[i, j] = trial.suggest_categorical(key, candidate)

    return evaluate(answer)


def run(n_trials):
    study_name = 'sudoku'
    study = optuna.create_study(study_name=study_name, storage='sqlite:///sudoku.db', load_if_exists=True)
    study.optimize(objective, n_trials=n_trials)

    show_result(study.best_params, study.best_value)

    df = study.trials_dataframe()
    df.to_csv('tpe_result.csv')


def show_result(best_params, best_value):
    for i in range(9):
        for j in range(9):
            key = 'p{}{}'.format(i, j)
            if key in preset:
                print('*{:1d}'.format(preset[key]), end='')
            else:
                print('{:2d}'.format(best_params[key]), end='')
        print('')
    print('loss: {}'.format(best_value))


@click.command()
@click.option('--n-trials', type=int, default=1000)
def cmd(n_trials):
    run(n_trials)


def main():
    cmd()


if __name__ == '__main__':
    main()

実験結果

期待してここまで読んでくださった方には申し訳ないです。

全然解けませんでした。

想定よりも1回あたりの評価にかかる時間が遥かに長くて試行回数が全然足りていません。
試行回数が0回で約7秒、200回で約25秒です。

このまま続けても正解にたどり着くには数十万~数百万回の試行回数が必要と思うので、現実的には達成不可能だと思います。

最後に、一応は200回終了時の出力を示します。

*5*3 5 5*7 2 6 1 2
*6 7 8*1*9*5 7 3 5
 4*9*8 6 3 7 8*6 4
*8 1 4 7*6 9 7 6*3
*4 3 9*8 5*3 2 7*1
*7 8 9 3*2 1 5 4*6
 3*6 1 6 7 9*2*8 7
 5 7 3*4*1*9 8 1*5
 4 7 6 2*8 4 3*7*9
loss: 73.0

左側に*がある数字は、解答のヒントとして最初からある数字です。

まとめ

optuna(TPEアルゴリズム)で数独は解けません。
試行回数でごり押しをする予定でしたが、計算時間が想定よりもはるかに長くてごり押しができませんでした。

optuna(TPEアルゴリズム)は試行回数が増えるとサンプリングにかかる時間が増えます。

今後の予定

今回の失敗の原因はTPEアルゴリズムの計算時間が長かったことなので、別のアルゴリズムを使えばこの問題は解決できると思います。
そもそもTPEアルゴリズムは条件付き分布を考えないので、数独を解くのにはかなり不向きなアルゴリズムではあります。

次回は条件付き分布からサンプリングを行うアルゴリズムで試したいと思います。
そのため、数独のルールを理解した上で解くことになります。
今回のように数独のルールを知らずに正解に辿り着くのは無謀な試みだったのでしょう。

予定は完全に未定です。
アルゴリズムはsimulated annealing系になると思いますが、これから調べるので投稿時期は未定です。

4
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
3