Help us understand the problem. What is going on with this article?

量子アニーリングで数独を解く

概要

  • 数独はグラフ彩色問題とみなすことができる.
  • グラフ彩色問題は QUBO にすることができる.
  • そこで数独を頑張って QUBO に置き換え,
  • D-Wave 社の QUBO solver (QBSolv)を使って解く.

Jupyter Notebook

パッケージのインストール

  • QBSolv
pip install dwave-qbsolv
  • Networkx
pip install networkx

準備

  • 数独のルールに従い準備をする
  • 数独では縦,横,3x3 ブロックに属するマスと異なる数字を入れる必要がある
  • まずはインポートなど
from dwave_qbsolv import QBSolv
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
n=9 # num of rows
  • マスのインデックスidx
  • それを 9x9 の行列にしたgridを用意する
idxs=np.arange(n**2)
grid=np.arange(n**2).reshape(n,n)
  • 3x3 のブロックを表すblockを用意する
    • 同じブロックに属するマスには,同じ数字は入らない
# list of indices of blocks
block=np.stack([grid[i*3:(i+1)*3,j*3:(j+1)*3].flatten() for i in range(3) for j in range(3)])
  • 同じ数字が入らないグループ(unit)をまとめたunitlistを用意する
    • gridを使うことで,行・列の unit を表現できる
# list of indices of row, col, block
unitlist=np.concatenate([grid, grid.T,block])
  • それぞれのマスが属する unit の lookup table unitsを用意する
    • units[i]は,第 i マスの行,列,ブロック unit を表す 3x9 行列からなる
# list of units. shape (81, 3, 9)
def get_units(s):
    return unitlist[np.isin(unitlist,s).sum(-1).astype(bool)]
units=np.array([get_units(s) for s in range(n**2)])
  • unit を共有するマスを "peer" と呼ぶ
    • それぞれのマスには,20 個の "peer" があることになる
  • peerの lookup table peersを用意する
    • peers[i]は第 i マスの peer のリストとなる
# list of peers
def get_peers(s):
    a=np.unique(get_units(s))
    idx=a!=s
    return a[idx].tolist()

# peers[i] indicates peers of i-th square.
peers=np.array([get_peers(s) for s in idxs])
  • 少し長くなったが,準備おわり

例題

  • 9x9=81 マスの初期状態を表している.
    • 0 は空白マスを示す.
a= '003020600900305001001806400008102900700000008006708200002609500800203009005010300'
  • パーサー
    • 9x9 の ndarray にして返す
digits = np.arange(n)
def parse(grid):
    assert len(grid) == 81
    return np.array([int(c)-1 if c in (digits+1).astype(str) else -1 for c in grid])

数独のグラフ表現

  • グラフで表現するには,81 個のノードをそれぞれの peer とつないでやればよい
    • 辺でつながったノード同士は同じ数字を入れることができない
  • つまり数独は 81 ノード,$81 \times 20\ /\ 2 = 810$エッジのグラフ彩色問題とみなせる

Networkx で描画してみる

  • グラフ準備
g=nx.Graph()
  • エッジを加える
for i in idxs:
    for j in peers[i]:
        g.add_edge(i,j)
  • 描画
    • きれいに並べるために pos を用意している
theta=np.arange(n**2)
x=np.cos(2*np.pi*theta/n**2)
y=np.sin(2*np.pi*theta/n**2)
pos={i:(x[i],y[i]) for i in range(n**2)}
nx.draw_networkx(g,pos=pos)

qiita_14_0.png

  • なかなか気持ち悪い

QUBO 行列を作る

  • 81x9 = 729 の qubit $q_{ij}$で状態を表現する
    • 第 i マスに 数字 j が入るとき,$q_{ij}=1$となり,それ以外の場合は 0
  • 3 種類の制約条件を適用する
    1. 1 つのマスには 1 つの数字しか入らない $$\sum_j q_{ij} = 1 \tag{1}$$
    2. peer のマスとは色が異なる
    3. 初期条件

第一の制約

  • 式(1)を qubo に変換するには,(1)が満たされるときのみ最小値を取るようなものを考える
    • これは下のようなエネルギー関数を作ると良い $$Ha_i=(1-\sum_j q_{ij})^2$$
    • $q_{ij} \in {0,1}$から$q_{ij}^2=q_{ij}$に注意すると,上式は下のようになる $$Ha_i=- q_{ii}^2 + \sum_{j\neq k}q_{ij}q_{ik} + 1$$
    • 定数項は除いてもよいが,収束判定に使える (今回は数独のルールベースでの判定をするので使わない)
  • 今回,qubit $q_{ij}$を 81x9 の行列として表現しているので,QUBO は(81, 9, 81, 9)の ndarray にするとわかりやすい
    • つまり$q_{ij}q_{kl}$の係数は$qubo_{ijkl}$とする
    • あとで QBSolv にわたす形に変換する
alpha=1. # 適当な係数
qubo=np.zeros((n**2,n,n**2,n)) # quboの準備
qubo[range(n**2),:,range(n**2),:] += alpha*(np.ones((n,n))-2*np.eye(n))

第二の制約:隣接ノードを異なる数字にする

  • i マスと j マスが同じ unit にいるとき,$q_{ik}\neq q_{jk}$となってほしい
  • これは$q_{ik}q_{jk}$をエネルギー関数とすれば良い
    • マスに入る数字が異なれば 0,同じならば 1 となるので,異なるときのみ最小値を取る
# さっき作ったグラフをつかう
for e in g.edges:
    i,j=e
    qubo[i,range(n),j,range(n)]+=1

第三の制約:初期条件

  • i マスの初期値が j ならば,$-q_{ij}$をエネルギー関数とすれば良い
    • 初期値を満たすときのみ最小値-1 を取る
# 問題aの初期値
init_values=parse(a)

beta=10.
i = np.where(init_values>=0)[0]:
j=init_values[i]
qubo[i,j,i,j]=-beta

QUBO 行列を適当な形式に変換する

  • QBSolv にわたすときは,{(i,j): Q[i,j]}のように辞書で渡す
    • つまり {"qubit i, j" : "その 2 次係数"}という対応を表してやれば良い
  • ますはquboを行列_Qに変形する
idx=np.unravel_index(np.arange(n**2*n),(n**2,n)) # idxは後で使う
_Q=qubo[idx[0],idx[1],:,:]
_Q=_Q[:,idx[0],idx[1]] # shape (729, 729)
  • QBSolv にわたすために辞書Qにする
Q={}
# ゼロでない係数のみ渡す
for i in zip(*np.nonzero(_Q)):
    Q[i]=_Q[i]

QBSolv で解く

  • 先程作った Qを渡せば良い
    • 一行で解いてくれる!
res=QBSolv().sample_qubo(Q)
  • 結果を解釈する
def get_answer(init_values, dct):
    x=np.zeros((n**2,n))
    y=np.zeros(n**3)
    y[list(dct.keys())]=np.array(list(dct.values()))
    x[idx]=y
    values=x @ np.arange(n)
    values[init_values>=0]=init_values[init_values>=0]
    return values

answer=get_answer(init_values,list(res.samples())[0])
  • 結果判定
def check_answer(answer):
    # 埋められていないマス(=-1)があるならば不正解
    if np.min(answer)<0:
        return False
    # すべてのマスがpeersと異なるならば正解
    return (answer[peers]-answer[:,None] == 0).sum() == 0
print(check_answer(answer))
True

とけた!

結果の出力

def display(values):
    x=values.reshape(n,n)
    line = "-"*19
    for r in range(n):
        print(*list(str(int(x[r,c])+1)+('|' if c == 2 or c == 5 else '') for c in range(n)))
        if r == 2 or r==5: print(line)
display(answer)
4 8 3| 9 2 1| 6 5 7
9 6 7| 3 4 5| 8 2 1
2 5 1| 8 7 6| 4 9 3
-------------------
5 4 8| 1 3 2| 9 7 6
7 2 9| 5 6 4| 1 3 8
1 3 6| 7 9 8| 2 4 5
-------------------
3 7 2| 6 8 9| 5 1 4
8 1 4| 2 5 3| 7 6 9
6 9 5| 4 1 7| 3 8 2

おわりに

Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away