概要
Streamlit で実住所に対する巡回セールスマン問題をGoogleMapに出力するアプリにて紹介したStreamlit アプリについて,今回は,Streamlit in Snowflake(SiS)を用いて実装する.大きな違いは,インターネット疎通可能なローカルのStreamlitの場合,Google Map API の接続は特に意識しなくてよいが,Snowflake を用いる場合,Snowflake 側でGoogle Map に接続するための外部ネットワークの設定を行う必要がある.これはSnowflake のアーキテクチャがそのように設計されるためである.
Google Map API への接続についてはSnowflake からGoogle Map API をたたきに行く方法にて解説している.
UI
入力
出力
Streamlit in Snowflake(SiS)のバグのようなものがいくつか見つかったため,別に記事にてまとめたいと思う.今回は,folium を使った地図の出力が動作しなかったため,pydeck を用いて地図の出力を行っている.また,pydeck を使う上でもバグのようなものがあり,テキストレイヤの出力が動作しなかった.
サンプルコード
import streamlit as st
import pydeck as pdk
from snowflake.snowpark.context import get_active_session
import pandas as pd
import numpy as np
import itertools
import json
import time
# Snowflake セッション取得
session = get_active_session()
# **全列挙型アルゴリズム**(Brute Force)
def solve_tsp_brute_force(dist_matrix):
locations_count = len(dist_matrix)
min_distance = float('inf')
min_route = None
for perm in itertools.permutations(range(locations_count)):
route_distance = sum(dist_matrix[perm[i], perm[i + 1]] for i in range(locations_count - 1))
route_distance += dist_matrix[perm[-1], perm[0]]
if route_distance < min_distance:
min_distance = route_distance
min_route = perm
return min_route, min_distance
# **近傍法**(Greedy Algorithm)
def solve_tsp_greedy(dist_matrix):
locations_count = len(dist_matrix)
unvisited = set(range(locations_count))
current_location = 0
route = [current_location]
total_distance = 0
unvisited.remove(current_location)
while unvisited:
nearest_neighbor = min(unvisited, key=lambda x: dist_matrix[current_location, x])
route.append(nearest_neighbor)
total_distance += dist_matrix[current_location, nearest_neighbor]
current_location = nearest_neighbor
unvisited.remove(current_location)
total_distance += dist_matrix[route[-1], route[0]]
return route, total_distance
# 住所を緯度・経度に変換
def geocode_address(address):
result = session.sql(f"SELECT <DATABASE_NAME>.<SCHEMA_NAME>.geocode_address('{address}')").collect()
if result:
column_name = list(result[0].as_dict().keys())[0]
geo_data = json.loads(result[0][column_name])
return geo_data[0], geo_data[1]
return None, None
# 住所間の移動時間を取得
def get_travel_time_matrix(addresses):
addr_list = ", ".join([f"'{addr}'" for addr in addresses])
result = session.sql(f"SELECT <DATABASE_NAME>.<SCHEMA_NAME>.get_travel_time_matrix(ARRAY_CONSTRUCT({addr_list}))").collect()
if result:
column_name = list(result[0].as_dict().keys())[0]
raw_matrix = result[0][column_name]
if raw_matrix is None:
st.error("移動時間データが取得できませんでした。")
return None
try:
return np.array(json.loads(raw_matrix))
except Exception as e:
st.error(f"Error parsing travel matrix: {e}")
return None
return None
# Streamlit UI
st.title("巡回セールスマン問題(TSP)")
# 住所の入力
addresses_input = st.text_area("住所を改行区切りで入力してください", height=200)
# アルゴリズム選択
algo_choice = st.radio("アルゴリズムを選択してください:", ("全列挙型アルゴリズム", "近傍法"))
# 地図表示ボタン
if st.button("最適ルートを表示"):
if addresses_input.strip():
addresses = addresses_input.strip().split("\n")
locations = []
for address in addresses:
lat, lon = geocode_address(address)
if lat is not None and lon is not None:
locations.append((lat, lon))
else:
st.error(f"住所 {address} の緯度・経度を取得できませんでした。")
break
if locations:
dist_matrix = get_travel_time_matrix(addresses)
if dist_matrix is not None:
st.text("計算中...お待ちください")
start_time = time.time()
if algo_choice == "全列挙型アルゴリズム":
min_route, min_distance = solve_tsp_brute_force(dist_matrix)
elif algo_choice == "近傍法":
min_route, min_distance = solve_tsp_greedy(dist_matrix)
end_time = time.time()
st.write(f"計算時間: {end_time - start_time:.2f} 秒")
st.write(f"最短経路の移動時間: {min_distance / 60:.2f} 分")
# 訪問順を出力
visit_order = [addresses[i] for i in min_route]
st.write("訪問順序:")
for idx, addr in enumerate(visit_order, 1):
st.write(f"{idx}. {addr}")
# 地図データの作成
locations_route = [locations[i] for i in min_route]
locations_route.append(locations_route[0]) # 最後にスタート地点に戻る
# マーカー表示用のデータ
points_data = pd.DataFrame(
[{"lat": lat, "lon": lon} for lat, lon in locations_route]
)
# 訪問順のルート線を作成
route_data = pd.DataFrame(
[{"start": [start[1], start[0]], "end": [end[1], end[0]]} for start, end in zip(locations_route[:-1], locations_route[1:])]
)
# マーカーのレイヤー
points_layer = pdk.Layer(
"ScatterplotLayer",
data=points_data,
get_position="[lon, lat]",
get_radius=100,
get_color=[255, 0, 0],
pickable=True,
)
# 訪問順のルート線のレイヤー
line_layer = pdk.Layer(
"LineLayer",
data=route_data,
get_source_position="start",
get_target_position="end",
get_color=[0, 0, 255],
get_width=3
)
# 地図の中心を設定
map_center = locations_route[0] if locations_route else [35.6895, 139.6917] # デフォルト: 東京
view_state = pdk.ViewState(latitude=map_center[0], longitude=map_center[1], zoom=12, pitch=0)
# 地図の背景をmap_style="mapbox://styles/mapbox/streets-v11"に設定
map_deck = pdk.Deck(
layers=[points_layer, line_layer],
initial_view_state=view_state,
map_style="mapbox://styles/mapbox/streets-v11"
)
st.pydeck_chart(map_deck)
else:
st.error("移動時間の取得に失敗しました。APIキーや住所の確認をしてください。")
else:
st.warning("住所を入力してください。")