#最適化問題を勉強中
数学は得意なので内容はすっと入ってくるんですが、そのコードを読み解くのって難しいですね。。。
最短道問題をハンズオンするために、こちらのコードを読み解いています。
僕自身初心者なので理解するのに時間がかかるけど記録のためにメモメモ。
#コード
フルのコードはここから
##問題設定
ノード数8で0.26の確率で道(エッジと呼ぶ)が作られる問題設定。
LpProblemは引数無しだと最小化問題になる (引用はドキュメントから)
fast_gnp_random_graphメソッドのドキュメントはここから
from pulp import *
import networkx as nx
g = nx.fast_gnp_random_graph(8, 0.26, 1).to_directed()
source, sink = 0, 2 # 始点, 終点
描画用のコードはこの記事から引用
import random, numpy as np, pandas as pd, networkx as nx, matplotlib.pyplot as plt
from itertools import chain, combinations
from pulp import *
def draw(g):
"""描画"""
nx.draw_networkx_labels(g, pos=pos)
nx.draw_networkx_nodes(g, node_color='w', pos=pos)
nx.draw_networkx_edges(g, pos=pos)
plt.show()
先ほどの問題設定を描画するとこんな感じ。
fast_gnp_random_graphメソッドの第二引数を1(すべてのノードを道でつなぐ)にすると
##変数設定と目的関数のセット
この記事では0から1の間の変数とだけ定義しているから、整数とは限らなそう。
LpVaribaleのcat引数をIntegerにすれば整数問題にもできる様子。
参考はpulpのドキュメント
r = list(enumerate(g.edges()))
m = LpProblem() # 数理モデル
x = [LpVariable('x%d'%k, lowBound=0, upBound=1) for k, (i, j) in r] # 変数(路に入るかどうか)
m += lpSum(x) # 目的関数。xをぜえんぶ足す
ちなみに
list(enumerate(g.edges())) #エッジ番号, (始点、終点)
>>>[(0, (0, 1)),
(1, (1, 0)),
(2, (1, 4)),
(3, (1, 6)),
(4, (1, 7)),
(5, (2, 5)),
(6, (3, 5)),
(7, (3, 6)),
(8, (4, 1)),
(9, (5, 2)),
(10, (5, 3)),
(11, (6, 1)),
(12, (6, 3)),
(13, (7, 1))]
[LpVariable('x%d'%k, lowBound=0, upBound=1) for k, (i, j) in r]
>>>[x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13]
##制約条件足す
ここが難しかった。もっと勉強せねば・・・
for nd in g.nodes():
m += lpSum(x[k] for k, (i, j) in r if i == nd) \
== lpSum(x[k] for k, (i, j) in r if j == nd) + {source:1, sink:-1}.get(nd, 0) # 制約
まず
lpSum(x[k] for k, (i, j) in r if i == nd))
#ndというnodeから出るエッジ(道)の総和。1からでる道は4本でx1 + x2 + x3 +x4
>>>x0
x1 + x2 + x3 + x4
x5
x6 + x7
x8
x10 + x9
x11 + x12
x13
そして
lpSum(x[k] for k, (i, j) in r if j == nd)
#ndというnodeへ行くエッジの総和
>>>x1
x0 + x11 + x13 + x8
x9
x10 + x12
x2
x5 + x6
x3 + x7
x4
経路において、通過するノードは出るノードから出る本数と入る本数が一致するはず。
始点の場合は出る本数のほうが一本多く、終点の場合は入る本数のほうが一本多いはず。
なので
lpSum(x[k] for k, (i, j) in r if i == nd) \
== lpSum(x[k] for k, (i, j) in r if j == nd) + {source:1, sink:-1}.get(nd, 0)
なお、辞書のgetメソッドはキーを探すためのメソッド。この記事を参照。
終点始点の場合分けをしており、左辺(出る本数)と右辺(入る本数)に差をつけるための項。
##解く
valueはfrom pulp import * の*の中にあるモジュールのよう。
入った数と出た数の関係のみを参考にしているので、同じ場所は何回でも通ってよい&行ってないノードの存在を許すという問題。
m.solve()
print([(i, j) for k, (i, j) in r if value(x[k]) > 0.5])
#最後に
楽しい!