1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

SnowflakeAdvent Calendar 2024

Day 21

"Streamlit in Snowflake(SiS)" で巡回セールスマン問題を解くアプリを作ってみた

Posted at

やること

Snowflakeで巡回セールスマン問題を解いてみるで扱った問題について,Streamlit in Snowflake(SiS)で問題作成から解の探索(アルゴリズムによっては最適解の保証はない),解の出力,経路の可視化までを行うアプリを作成する.

  • 巡回する都市の数をパラメータとして設定
  • 各都市間の距離はランダムに生成
  • 解の探索アルゴリズムは単純な全探索アルゴリズムと最近傍法を用いたアルゴリズムの2つから選択
    • 単純な全探索アルゴリズムは特に何の工夫もしていない全探索アルゴリズムで,厳密な最適解が得られるが,計算時間が大きくなる.拠点数2桁あたりから一気に計算時間が大きくなる.
    • 最近傍法を用いたアルゴリズムは,近傍操作を行い解を探索するメタヒューリスティクス解法で,厳密な最適解が得られる保証がないが,計算量が小さい.拠点数が3桁でも現実的な時間で解が得られる.

サンプルコード

以下,Python on Snowflake で動作するサンプルコードである.
Snowflake のデータベースへのアクセスは行っていないが,応用すれば,Snowflake のデータベースからデータを入力し,結果を書き込むことも可能である.

import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import math
import itertools

# --- TSP Solver Functions ---
def calculate_total_distance(path, distance_matrix):
    """経路の総距離を計算"""
    total_distance = 0
    for i in range(len(path) - 1):
        total_distance += distance_matrix[path[i], path[i + 1]]
    total_distance += distance_matrix[path[-1], path[0]]  # 最後に戻る距離
    return total_distance

def simple_search(distance_matrix):
    """単純全列挙法でTSPを解く"""
    n = len(distance_matrix)
    best_path = None
    best_distance = float('inf')

    for perm in itertools.permutations(range(n)):
        current_distance = calculate_total_distance(perm, distance_matrix)
        if current_distance < best_distance:
            best_path = perm
            best_distance = current_distance

    return list(best_path), best_distance


def nearest_neighbor(distance_matrix):
    """最近傍法でTSPを解く"""
    n = len(distance_matrix)
    unvisited = set(range(n))
    current_city = 0
    unvisited.remove(current_city)
    path = [current_city]

    while unvisited:
        next_city = min(unvisited, key=lambda city: distance_matrix[current_city][city])
        unvisited.remove(next_city)
        path.append(next_city)
        current_city = next_city

    return path, calculate_total_distance(path, distance_matrix)


def plot_tsp_and_create_dataframe(cities, best_path, city_names, distance_matrix):
    """TSPの経路を地図上にプロットし、データフレームを作成"""
    ordered_cities = cities[best_path + [best_path[0]]]  # 最初の都市に戻る
    x, y = ordered_cities[:, 0], ordered_cities[:, 1]
    
    fig, ax = plt.subplots(figsize=(10, 8))
    ax.plot(x, y, 'o-', markersize=10, label="Path")  # 経路をプロット
    for i, name in enumerate(city_names):
        ax.text(cities[i, 0], cities[i, 1], name, fontsize=12, ha='center')
    ax.set_title("Traveling Salesman Problem Solution")
    ax.set_xlabel("X Coordinate")
    ax.set_ylabel("Y Coordinate")
    ax.legend()
    ax.grid()

    records = []
    total_distance = 0
    for i in range(len(best_path)):
        var1 = city_names[best_path[i]]
        var2 = city_names[best_path[(i + 1) % len(best_path)]]
        val = distance_matrix[best_path[i], best_path[(i + 1) % len(best_path)]]
        total_distance += val
        records.append([var1, var2, val])

    records.append(["opt", "opt", total_distance])
    df = pd.DataFrame(records, columns=["var1", "var2", "val"])
    return fig, df

# --- Streamlit UI ---
st.title("Traveling Salesman Problem Solver with Multiple Algorithms")

# ユーザー入力: 都市数
num_cities = st.sidebar.number_input("都市数を入力してください", min_value=2, max_value=100, value=5)

# サイドバー: アルゴリズム選択
algorithm = st.sidebar.selectbox("アルゴリズムを選択してください", ["Simple Search", "Nearest Neighbor"])

if st.button("都市を生成"):
    city_names = [f"City{i+1}" for i in range(num_cities)]
    cities = np.random.rand(num_cities, 2) * 100
    distance_matrix = np.linalg.norm(cities[:, None, :] - cities[None, :, :], axis=-1)

    st.write("### 都市の座標")
    city_data = pd.DataFrame(cities, columns=["X Coordinate", "Y Coordinate"], index=city_names)
    st.dataframe(city_data)

    # アルゴリズムに応じてTSPを解く
    if algorithm == "Simple Search":
        best_path, best_distance = simple_search(distance_matrix)
    elif algorithm == "Nearest Neighbor":
        best_path, best_distance = nearest_neighbor(distance_matrix)

    # 結果の表示
    fig, df_result = plot_tsp_and_create_dataframe(cities, best_path, city_names, distance_matrix)

    st.pyplot(fig)
    st.write("### 最適経路")
    st.dataframe(df_result)

2つのアルゴリズムによる結果

実行するたびに拠点の距離がランダムに生成される作りにしている関係で,同一の問題での比較になっていない点,留意いただきたい.単純な全探索アルゴリズムでは,拠点数が10を超えたあたりから計算時間が一気に大きくなるため,拠点数9までの結果を以下に示している.最近傍法を用いたアルゴリズムでは,拠点数を100としても数秒で結果が出力できるが,厳密な最適解となっていない点がわかる.また,拠点数が大きくなると,明らかに最適解から遠ざかっていることがわかる(初期解や入力データによって,良い解が得られる場合もある).

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?