AtCoder231 H - Minimum Coloringを対象としたサンプルプログラムを以下に示します。
格子状のマス目にコストが設定されていて、それぞれの行と列から一つ以上のマス目を
選択する場合のコストの合計の最小値を求める問題です。
pulpを使用した場合のサンプルプログラムを以下に示します。
本番環境ではpulpは使用できません。
H,W,N = list(map(int, input().split()))
abc_list = []
for _ in range(N):
a,b,c = list(map(int, input().split()))
abc_list.append((a-1,b-1,c))
import pulp
problem = pulp.LpProblem(sense=pulp.LpMinimize)
from collections import defaultdict
var_list = []
coef_list = []
row_dict = defaultdict(list)
col_dict = defaultdict(list)
for i in range(N):
a,b,c = abc_list[i]
n = f'v_{a}_{b}'
v = pulp.LpVariable(n, 0, 1, 'Integer')
var_list.append(v)
coef_list.append(c)
row_dict[a].append(v)
col_dict[b].append(v)
problem += pulp.lpDot(coef_list, var_list)
for v in row_dict.values():
problem += pulp.lpSum(v) >= 1
for v in col_dict.values():
problem += pulp.lpSum(v) >= 1
result = problem.solve(pulp.PULP_CBC_CMD(msg=False))
ans = int(pulp.value(problem.objective))
print(ans)
# print(pulp.LpStatus[result])
# for i in range(N):
# a,b,c = abc_list[i]
# v = var_list[i].value()
# print(a,b,c,int(v))
networkxを使用した場合のサンプルプログラムを以下に示します。
実装方法は以下の解説に従っています。
最小重み辺被覆を最小重み完全マッチングに変換して解きます。
H,W,N = list(map(int, input().split()))
abc_list = []
for _ in range(N):
a,b,c = list(map(int, input().split()))
abc_list.append((a-1,b-1,c))
INF = max([c for _,_,c in abc_list])
ac_min = [INF] * H
bc_min = [INF] * W
for a,b,c in abc_list:
ac_min[a] = min(ac_min[a], c)
bc_min[b] = min(bc_min[b], c)
import networkx as nx
G = nx.Graph()
# U=a, W=b+H+W, U'=a+H+W+W, W'=b+H
for a in range(H):
G.add_node(a)
G.add_node(a+H+W+W)
for b in range(W):
G.add_node(b+H)
G.add_node(b+H+W)
from collections import defaultdict
c_dict = defaultdict(int)
for a,b,c in abc_list:
G.add_edge(a,b+H+W,weight=c)
c_dict[(a,b+H+W)] = c
G.add_edge(b+H,a+H+W+W,weight=c)
c_dict[(b+H,a+H+W+W)] = c
for a in range(H):
G.add_edge(a,a+H+W+W,weight=ac_min[a]*2)
c_dict[(a,a+H+W+W)] = ac_min[a]*2
for b in range(W):
G.add_edge(b+H,b+H+W,weight=bc_min[b]*2)
c_dict[(b+H,b+H+W)] = bc_min[b]*2
top_nodes = list(range(H+W))
result = nx.bipartite.minimum_weight_full_matching(G,top_nodes=top_nodes)
ans = 0
for a,b in result.items():
if a > b: continue
ans += c_dict[(a,b)]
print(ans//2)