Notes on the Inverted Pendulum and Reinforcement Learning

Prompt Name
2026-01-25_rl_inverted_pendulum_tutor_and_colab_builder_v1

Purpose
To study the inverted pendulum (Inverted Pendulum / Cart-Pole / Pendulum) with reinforcement learning (RL), put together a single, reproducible treatment that runs from
(1) the dynamics model (equations) → (2) MDP formulation → (3) choice of learning algorithm → (4) implementation → (5) evaluation and visualization.

Role
You are an instructor and implementation engineer for "control engineering + reinforcement learning".
Avoid vague wording; explain in the order definition → equation → implementation.
State explicitly when something is undecided; do not assert guesses as facts.

User Context
- Target readers: high-school to early-undergraduate level + beginner-to-intermediate Python
- Goal: use the inverted pendulum to understand the essence of RL (value, policy, credit assignment, exploration)
- Execution environment: Google Colab assumed (GPU used if available, but not required)

Task
1) Present the options for the problem setting, then fix on one and proceed with it
   - (A) Cart-Pole (discrete actions) / (B) Pendulum (continuous actions) / (C) InvertedPendulum (MuJoCo family)
   * If the user does not specify, adopt (A).

2) Present the dynamics model (minimal is fine; reference equations are given below)
   - Definition of the state vector x (e.g. θ, θ_dot, cart position, cart velocity)
   - Equations of motion (plain-text math)
   - Linearization (if needed: approximation around the equilibrium point) and what it means
   - Always state the units of the physical parameters (m, M, l, g, etc.)
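
For reference, one standard form of the cart-pole equations of motion, with the pole modeled as a point mass m [kg] at distance l [m] from the pivot, cart mass M [kg], horizontal force u [N], gravity g [m/s^2], and theta measured from the upright position (illustrative only; the constants inside CartPole-v1 may use a different parametrization):

  (M + m)*x_ddot + m*l*theta_ddot*cos(theta) - m*l*theta_dot^2*sin(theta) = u
  m*l^2*theta_ddot + m*l*x_ddot*cos(theta) - m*g*l*sin(theta) = 0

Linearized around the upright equilibrium (sin(theta) ≈ theta, cos(theta) ≈ 1, theta_dot^2 terms dropped):

  (M + m)*x_ddot + m*l*theta_ddot = u
  l*theta_ddot + x_ddot - g*theta = 0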

3) Formulate the problem as an MDP (an illustrative CartPole-v1 summary is given below)
   - State S, actions A, reward R, terminal conditions, discount factor γ
   - Intent behind the reward design ("keep it upright" vs. "keep it safely constrained")
   - Difference between observations and the true state (if relevant)
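
As a concrete illustration for the default choice (A), a summary of CartPole-v1 under the standard Gymnasium settings (treat the numbers as assumptions if the environment is customized):

# Illustrative MDP summary for CartPole-v1 (standard Gymnasium settings).
CARTPOLE_MDP = {
    "S": "Box(4): [cart_position (m), cart_velocity (m/s), pole_angle (rad), pole_angular_velocity (rad/s)]",
    "A": "Discrete(2): 0 = push cart to the left, 1 = push cart to the right",
    "R": "+1 for every step taken, including the final one",
    "termination": "pole angle beyond +/-12 deg (about 0.2095 rad) or cart position beyond +/-2.4 m",
    "truncation": "episode length reaches 500 steps (time limit, not a true terminal state)",
    "gamma": 0.99,  # the discount factor is an agent-side choice, not part of the environment
}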

4) Choose the algorithm, with reasons (the default DQN target is sketched below)
   - (A) Discrete: which of DQN / Double DQN / Dueling / Prioritized to use
   - (B) Continuous: which of DDPG / TD3 / SAC / PPO to use
   - Choose based on sample efficiency, stability, and implementation difficulty
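
For the default choice (A), this prompt assumes plain DQN as a reasonable trade-off between stability and implementation difficulty at this scale. Its one-step update, in plain text:

  target:  y = r + gamma * max_a' Q_target(s', a')    (y = r if s' is a true terminal state)
  loss:    L = Huber( Q(s, a), y )

A replay buffer decorrelates the samples, and a slowly updated target network Q_target keeps the TD target stable.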

5) Implementation (immediately runnable in Colab; a minimal structural sketch is given below)
   - Required structure: PARAM_INIT (single source of all settings) → input → computation → output
   - Required libraries: numpy / matplotlib / ipywidgets / numba (if needed) / gymnasium
   - Explicit random seed; try-except required
   - Training logs: episode return, success rate, loss (if possible)
   - Visualization: learning curve, time series of the state (θ, etc.), and a saved video of a successful episode (if possible)
   - UI: sliders to change the number of episodes, learning rate, γ, exploration rate, and so on
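
A minimal sketch of the required cell structure (illustrative only; the full program below follows the same pattern with real training code):

import numpy as np

# PARAM_INIT: single source of truth for all settings
PARAM_INIT = {"SEED": 42, "EPISODES": 600, "LR": 1e-3, "GAMMA": 0.99}

def main(params):
    try:
        rng = np.random.default_rng(params["SEED"])    # input: explicit random seed
        returns = rng.normal(size=params["EPISODES"])  # compute: placeholder for the training loop
        print("mean return:", float(returns.mean()))   # output: logged result
    except Exception as e:
        print("[ERROR]", e)

main(PARAM_INIT)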

6) Evaluation and validation (a minimal comparison sketch is given below)
   - Before/after training comparison (mean return under identical conditions)
   - Generalization check (vary the initial angle and disturbances)
   - Typical failure modes (reward design / insufficient exploration / divergence) and countermeasures
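
A minimal sketch of the before/after comparison in item 6, assuming gymnasium and CartPole-v1; policy is any function mapping an observation to an action, e.g. a random policy before training and the greedy Q-network policy afterwards (greedy_action in the final comment is a hypothetical helper standing in for the trained policy):

import numpy as np
import gymnasium as gym

def mean_return(policy, episodes: int = 20, seed: int = 123) -> float:
    # Evaluate a policy over several episodes with fixed seeds and return the average episode return.
    env = gym.make("CartPole-v1")
    returns = []
    for ep in range(episodes):
        obs, _ = env.reset(seed=seed + ep)
        done = trunc = False
        ep_ret = 0.0
        while not (done or trunc):
            obs, r, done, trunc, _ = env.step(policy(obs))
            ep_ret += float(r)
        returns.append(ep_ret)
    env.close()
    return float(np.mean(returns))

print("random policy mean return:", mean_return(lambda obs: np.random.randint(2)))
# After training, compare with: mean_return(lambda obs: greedy_action(qnet, obs))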

Output Requirements
- The output must be a finished artifact that works when copy-pasted as-is
- Present the result with numbered chapters
- Math in plain text (LaTeX notation may be used, but always accompanied by plain text)
- No code omissions (a single script that runs end-to-end in Colab)
- Figure axis labels and titles in English
- Code comments bilingual (Japanese first, then English)
- Keep additional questions to the user to a minimum; by default, proceed with the defaults

Safety / Ethics
- Do not give instructions that encourage dangerous operation of real hardware or destructive acts
- Keep the output limited to explanations for educational and research purposes

Start Now
Following the above, first adopt "(A) Cart-Pole (discrete)", then output chapters 1 through 6 in order, covering equations → MDP → algorithm → Colab code → evaluation in a single response.
# Program Name: rl_cartpole_dqn_colab_allinone_v2.py
# Creation Date: 20260125
# Purpose: Reproducible CartPole DQN (Gymnasium) in Google Colab with pinned deps, auto-restart, UI, logging, plots, and video.

# ============================================================
# 0. SETUP (Pinned install + auto runtime restart) / セットアップ(版固定+自動再起動)
# ============================================================
# 日本語: NumPy ABI不整合(dtype size changed)を避けるため、依存を版固定で入れ直し、必要なら再起動する。
# English: Pin dependencies to avoid NumPy binary-compat issues and restart runtime if needed.

import os, sys, subprocess, textwrap, signal

SETUP_MARK = "/content/rl_setup_done_cartpole_v2.txt"

PINNED = {
    "numpy": "1.26.4",
    "matplotlib": "3.8.4",
    "ipywidgets": "8.1.1",
    "numba": "0.59.1",
    "gymnasium": "0.29.1",
    "imageio": "2.34.1",
    "imageio-ffmpeg": "0.5.1",
}

def _run(cmd):
    subprocess.check_call(cmd)

def _pip(args):
    _run([sys.executable, "-m", "pip"] + args)

def setup_and_maybe_restart():
    # Colab判定 / Detect Colab
    in_colab = ("google.colab" in sys.modules) or ("COLAB_GPU" in os.environ) or (os.path.exists("/content"))
    if not in_colab:
        return  # ローカルは再起動を強制しない / Do not force restart outside Colab

    if os.path.exists(SETUP_MARK):
        return

    print("=== Installing pinned dependencies (first run) ===")
    _pip(["install", "-q", "--upgrade", "pip"])

    # 既存の不整合を消すために一度アンインストール / Uninstall potentially conflicting packages
    _pip(["uninstall", "-y", "numpy", "gymnasium", "gymnasium-classic-control", "pyglet", "pygame"])
    _pip(["install", "-q",
          f"numpy=={PINNED['numpy']}",
          f"matplotlib=={PINNED['matplotlib']}",
          f"ipywidgets=={PINNED['ipywidgets']}",
          f"numba=={PINNED['numba']}",
          f"gymnasium=={PINNED['gymnasium']}",
          f"gymnasium[classic-control]=={PINNED['gymnasium']}",
          f"imageio=={PINNED['imageio']}",
          f"imageio-ffmpeg=={PINNED['imageio-ffmpeg']}",
    ])

    # torchはColabに通常入っているが、無い場合だけ / Torch is usually preinstalled in Colab
    try:
        import torch  # noqa
    except Exception:
        _pip(["install", "-q", "torch"])

    with open(SETUP_MARK, "w") as f:
        f.write("ok")

    print("\n=== Runtime restart required ===")
    print("Restarting now to finalize binary compatibility...")
    os.kill(os.getpid(), signal.SIGKILL)

setup_and_maybe_restart()

# ============================================================
# 1. IMPORTS / インポート
# ============================================================
import time
import math
import random
import datetime
from dataclasses import dataclass
from typing import Dict, Tuple, List, Optional

import numpy as np
import matplotlib.pyplot as plt

import gymnasium as gym

import ipywidgets as widgets
from IPython.display import display, clear_output, HTML

from numba import njit

import torch
import torch.nn as nn
import torch.optim as optim

# Colab widget manager / Colabのウィジェット有効化
try:
    from google.colab import output as colab_output
    colab_output.enable_custom_widget_manager()
except Exception:
    pass

# ============================================================
# 2. PARAM_INIT (Single source of truth) / 設定一元管理
# ============================================================
PARAM_INIT: Dict[str, object] = {
    # --- Environment / 環境 ---
    "ENV_ID": "CartPole-v1",
    "MAX_EPISODE_STEPS": 500,
    "SEED": 42,

    # --- DQN / 学習設定 ---
    "EPISODES": 600,
    "GAMMA": 0.99,
    "LR": 1e-3,
    "BATCH_SIZE": 128,
    "REPLAY_SIZE": 50_000,
    "LEARNING_STARTS": 1_000,
    "TRAIN_FREQ": 1,
    "TARGET_UPDATE": 1_000,
    "GRAD_CLIP_NORM": 10.0,

    # --- Exploration / 探索 ---
    "EPS_START": 1.0,
    "EPS_END": 0.05,
    "EPS_DECAY_STEPS": 50_000,

    # --- Network / NN ---
    "HIDDEN_SIZE": 256,

    # --- Device / デバイス ---
    "DEVICE": "cuda" if torch.cuda.is_available() else "cpu",

    # --- Evaluation / 評価 ---
    "EVAL_EPISODES": 20,
    "SUCCESS_SCORE": 475.0,
    "VIDEO_RECORD": True,
    "VIDEO_DIR": "/content/videos_cartpole",

    # --- Plot / 可視化 ---
    "PLOT_EVERY": 20,
    "PLOT_MODE": "return",  # "return" or "loss"
    "SAVE_PLOTS": True,
}

# ============================================================
# 3. UTILITIES / ユーティリティ
# ============================================================
def now_tag() -> str:
    return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

def set_global_seed(seed: int) -> None:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def print_env_info(params: Dict[str, object]) -> None:
    print("=== Runtime Info ===")
    print(f"Time: {datetime.datetime.now().isoformat(timespec='seconds')}")
    print(f"Env: {params['ENV_ID']}")
    print(f"Device: {params['DEVICE']}")
    print(f"python: {sys.version.split()[0]}")
    print(f"torch: {torch.__version__}")
    print(f"gymnasium: {gym.__version__}")
    print(f"numpy: {np.__version__}")
    print(f"matplotlib: {plt.matplotlib.__version__}")

def print_param_table(params: Dict[str, object]) -> None:
    keys = [
        "ENV_ID","SEED","EPISODES","GAMMA","LR","BATCH_SIZE","REPLAY_SIZE",
        "LEARNING_STARTS","TRAIN_FREQ","TARGET_UPDATE","EPS_START","EPS_END",
        "EPS_DECAY_STEPS","HIDDEN_SIZE","DEVICE","EVAL_EPISODES","SUCCESS_SCORE"
    ]
    print("=== PARAM_INIT (selected) ===")
    for k in keys:
        print(f"{k:>16s}: {params[k]}")

@njit
def moving_average_numba(x: np.ndarray, window: int) -> np.ndarray:
    n = x.size
    if window <= 1:
        return x.copy()
    out = np.empty(n, dtype=np.float64)
    s = 0.0
    for i in range(n):
        s += x[i]
        if i >= window:
            s -= x[i - window]
            out[i] = s / window
        else:
            out[i] = s / (i + 1)
    return out

def linear_epsilon(step: int, eps_start: float, eps_end: float, decay_steps: int) -> float:
    if decay_steps <= 0:
        return eps_end
    t = min(max(step, 0), decay_steps)
    return eps_start + (eps_end - eps_start) * (t / decay_steps)

def print_theory_cartpole() -> None:
    print("=== 1) Dynamics (Cart-Pole) / 力学モデル(要点) ===")
    print("State x (observation in CartPole-v1):")
    print("  x = [cart_position, cart_velocity, pole_angle, pole_angular_velocity]")
    print("Units:")
    print("  cart_position [m], cart_velocity [m/s], pole_angle [rad], pole_angular_velocity [rad/s]")
    print("Conceptual form:")
    print("  x_dot = f(x, u), u is horizontal force (discrete left/right).")
    print("Linearization idea (around upright):")
    print("  For small |theta|, sin(theta)≈theta, cos(theta)≈1 -> x_dot ≈ A x + B u (LQR style).")
    print("")
    print("=== 2) MDP / MDP定式化 ===")
    print("S: 4D observation, A: {0,1}, R: +1 per step, Terminal: angle/position limit or time limit, gamma: 0.99")
    print("")
    print("=== 3) DQN / DQN ===")
    print("Bellman optimality:")
    print("  Q*(s,a) = E[ r + γ * max_a' Q*(s',a') ]")
    print("DQN uses replay buffer + target network to stabilize TD learning.")

# ============================================================
# 4. REPLAY BUFFER / 経験再生バッファ
# ============================================================
@dataclass
class ReplayBuffer:
    capacity: int
    obs_dim: int

    def __post_init__(self):
        self.obs = np.zeros((self.capacity, self.obs_dim), dtype=np.float32)
        self.next_obs = np.zeros((self.capacity, self.obs_dim), dtype=np.float32)
        self.acts = np.zeros((self.capacity,), dtype=np.int64)
        self.rews = np.zeros((self.capacity,), dtype=np.float32)
        self.dones = np.zeros((self.capacity,), dtype=np.float32)
        self.ptr = 0
        self.size = 0

    def add(self, o: np.ndarray, a: int, r: float, no: np.ndarray, d: bool) -> None:
        i = self.ptr
        self.obs[i] = o
        self.acts[i] = a
        self.rews[i] = r
        self.next_obs[i] = no
        self.dones[i] = 1.0 if d else 0.0
        self.ptr = (self.ptr + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def sample(self, batch_size: int, rng: np.random.Generator):
        idx = rng.integers(0, self.size, size=batch_size)
        return (self.obs[idx], self.acts[idx], self.rews[idx], self.next_obs[idx], self.dones[idx])

# ============================================================
# 5. Q-NETWORK / Qネットワーク
# ============================================================
class QNetwork(nn.Module):
    def __init__(self, obs_dim: int, act_dim: int, hidden: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(obs_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, act_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

# ============================================================
# 6. ENV + POLICY + UPDATE / 環境・方策・更新
# ============================================================
def make_env(env_id: str, seed: int, max_steps: int, record_video: bool=False, video_dir: str="/content/videos") -> gym.Env:
    # 日本語: gym.make に max_episode_steps を渡し、TimeLimit の二重ラップを避ける。
    # English: Pass max_episode_steps to gym.make so the built-in TimeLimit is not double-wrapped.
    env = gym.make(env_id, max_episode_steps=max_steps, render_mode="rgb_array" if record_video else None)
    env.reset(seed=seed)
    env.action_space.seed(seed)
    env.observation_space.seed(seed)
    if record_video:
        os.makedirs(video_dir, exist_ok=True)
        env = gym.wrappers.RecordVideo(
            env,
            video_folder=video_dir,
            episode_trigger=lambda ep: True,
            name_prefix=f"{env_id}_{now_tag()}",
            disable_logger=True
        )
    return env

def select_action(qnet: QNetwork, obs: np.ndarray, eps: float, act_dim: int, device: str) -> int:
    if random.random() < eps:
        return random.randrange(act_dim)
    with torch.no_grad():
        x = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
        q = qnet(x)
        return int(torch.argmax(q, dim=1).item())

def dqn_update(qnet: QNetwork, target: QNetwork, opt: optim.Optimizer, batch, gamma: float, device: str, grad_clip_norm: float) -> float:
    obs, acts, rews, next_obs, dones = batch
    obs_t = torch.tensor(obs, dtype=torch.float32, device=device)
    acts_t = torch.tensor(acts, dtype=torch.int64, device=device).unsqueeze(1)
    rews_t = torch.tensor(rews, dtype=torch.float32, device=device).unsqueeze(1)
    next_obs_t = torch.tensor(next_obs, dtype=torch.float32, device=device)
    dones_t = torch.tensor(dones, dtype=torch.float32, device=device).unsqueeze(1)

    q_sa = qnet(obs_t).gather(1, acts_t)
    with torch.no_grad():
        max_next_q = target(next_obs_t).max(dim=1, keepdim=True).values
        td_target = rews_t + gamma * max_next_q * (1.0 - dones_t)

    loss = nn.functional.smooth_l1_loss(q_sa, td_target)

    opt.zero_grad(set_to_none=True)
    loss.backward()
    if grad_clip_norm is not None and grad_clip_norm > 0:
        nn.utils.clip_grad_norm_(qnet.parameters(), grad_clip_norm)
    opt.step()
    return float(loss.item())

def evaluate_policy(params: Dict[str, object], qnet: QNetwork, seed_offset: int = 2000, record_video: bool = True) -> Dict[str, float]:
    env = make_env(params["ENV_ID"], int(params["SEED"]) + seed_offset, int(params["MAX_EPISODE_STEPS"]),
                   record_video=record_video, video_dir=str(params["VIDEO_DIR"]))
    device = str(params["DEVICE"])
    act_dim = env.action_space.n

    returns, lengths = [], []
    for ep in range(int(params["EVAL_EPISODES"])):
        obs, _ = env.reset(seed=int(params["SEED"]) + seed_offset + ep)
        done = False
        trunc = False
        ep_ret = 0.0
        ep_len = 0
        while not (done or trunc):
            a = select_action(qnet, obs, eps=0.0, act_dim=act_dim, device=device)
            obs, r, done, trunc, _ = env.step(a)
            ep_ret += float(r)
            ep_len += 1
        returns.append(ep_ret)
        lengths.append(ep_len)

    env.close()
    # 日本語: 成功率 = return >= SUCCESS_SCORE のエピソード割合
    # English: Success rate = fraction of episodes with return >= SUCCESS_SCORE.
    success_rate = float(np.mean([ret >= float(params["SUCCESS_SCORE"]) for ret in returns]))
    return {
        "eval_return_mean": float(np.mean(returns)),
        "eval_return_std": float(np.std(returns)),
        "eval_len_mean": float(np.mean(lengths)),
        "eval_success_rate": success_rate,
    }

def show_latest_video(video_dir: str):
    # 日本語: RecordVideoが作ったmp4を表示(最新)
    # English: Display the newest mp4 in Colab output.
    if not os.path.isdir(video_dir):
        print("No video directory found.")
        return
    mp4s = []
    for root, _, files in os.walk(video_dir):
        for f in files:
            if f.lower().endswith(".mp4"):
                mp4s.append(os.path.join(root, f))
    if not mp4s:
        print("No mp4 found.")
        return
    mp4s.sort(key=lambda p: os.path.getmtime(p))
    path = mp4s[-1]
    print("Video:", path)
    with open(path, "rb") as f:
        data = f.read()
    import base64
    b64 = base64.b64encode(data).decode("ascii")
    display(HTML(f"""
    <video width="640" controls>
      <source src="data:video/mp4;base64,{b64}" type="video/mp4">
    </video>
    """))

# ============================================================
# 7. TRAIN LOOP / 学習ループ
# ============================================================
def train_dqn(params: Dict[str, object]) -> Dict[str, object]:
    set_global_seed(int(params["SEED"]))
    rng = np.random.default_rng(int(params["SEED"]))

    env = make_env(str(params["ENV_ID"]), int(params["SEED"]), int(params["MAX_EPISODE_STEPS"]), record_video=False)

    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    device = str(params["DEVICE"])

    qnet = QNetwork(obs_dim, act_dim, int(params["HIDDEN_SIZE"])).to(device)
    target = QNetwork(obs_dim, act_dim, int(params["HIDDEN_SIZE"])).to(device)
    target.load_state_dict(qnet.state_dict())
    target.eval()

    opt = optim.Adam(qnet.parameters(), lr=float(params["LR"]))
    rb = ReplayBuffer(int(params["REPLAY_SIZE"]), obs_dim)

    episode_returns, episode_lengths, losses = [], [], []
    steps_total = 0

    for ep in range(int(params["EPISODES"])):
        obs, _ = env.reset(seed=int(params["SEED"]) + ep)
        done = False
        trunc = False
        ep_ret = 0.0
        ep_len = 0

        while not (done or trunc):
            eps = linear_epsilon(
                steps_total,
                float(params["EPS_START"]),
                float(params["EPS_END"]),
                int(params["EPS_DECAY_STEPS"])
            )
            a = select_action(qnet, obs, eps, act_dim, device)
            next_obs, r, done, trunc, _ = env.step(a)

            # 日本語: ブートストラップ遮断には終端(termination)のみを使う。打ち切り(truncation)は真の終端ではない。
            # English: Use only true termination to cut bootstrapping; time-limit truncation is not a real terminal state.
            rb.add(obs, a, float(r), next_obs, bool(done))

            obs = next_obs
            ep_ret += float(r)
            ep_len += 1
            steps_total += 1

            if rb.size >= int(params["LEARNING_STARTS"]) and (steps_total % int(params["TRAIN_FREQ"]) == 0):
                batch = rb.sample(int(params["BATCH_SIZE"]), rng)
                loss_val = dqn_update(
                    qnet, target, opt, batch,
                    gamma=float(params["GAMMA"]),
                    device=device,
                    grad_clip_norm=float(params["GRAD_CLIP_NORM"])
                )
                losses.append(loss_val)

            if steps_total % int(params["TARGET_UPDATE"]) == 0:
                target.load_state_dict(qnet.state_dict())

        episode_returns.append(ep_ret)
        episode_lengths.append(ep_len)

        if (ep + 1) % max(1, int(params["PLOT_EVERY"])) == 0:
            recent = episode_returns[-int(params["PLOT_EVERY"]):]
            print(f"[Train] ep={ep+1:4d}/{params['EPISODES']}  "
                  f"return_mean_recent={np.mean(recent):7.2f}  steps_total={steps_total:7d}  eps={eps:5.3f}")

    env.close()

    return {
        "qnet": qnet,
        "target": target,
        "episode_returns": np.array(episode_returns, dtype=np.float64),
        "episode_lengths": np.array(episode_lengths, dtype=np.float64),
        "losses": np.array(losses, dtype=np.float64),
        "steps_total": steps_total,
        "params": params.copy(),
    }

# ============================================================
# 8. PLOT / プロット
# ============================================================
def plot_training_curves(logs: Dict[str, object], window: int = 20, plot_mode: str = "return") -> None:
    returns = logs["episode_returns"]
    losses = logs["losses"]

    plt.figure()
    if plot_mode == "loss":
        if losses.size == 0:
            plt.title("Training Loss (no data)")
            plt.xlabel("Update step")
            plt.ylabel("Huber loss")
            plt.grid(True)
            plt.show()
            return
        plt.plot(losses)
        plt.title("Training Loss (Huber)")
        plt.xlabel("Update step")
        plt.ylabel("Huber loss")
        plt.grid(True)
    else:
        ma = moving_average_numba(returns.astype(np.float64), window)
        plt.plot(returns, label="Return")
        plt.plot(ma, label=f"MovingAvg(window={window})")
        plt.title("Episode Return")
        plt.xlabel("Episode")
        plt.ylabel("Return")
        plt.grid(True)
        plt.legend()
    plt.show()

def save_plot_files(logs: Dict[str, object], window: int = 20) -> Tuple[str, str]:
    tag = now_tag()
    png_path = f"/content/train_curve_{tag}.png"
    pdf_path = f"/content/train_curve_{tag}.pdf"

    returns = logs["episode_returns"]
    ma = moving_average_numba(returns.astype(np.float64), window)

    plt.figure()
    plt.plot(returns, label="Return")
    plt.plot(ma, label=f"MovingAvg(window={window})")
    plt.title("Episode Return")
    plt.xlabel("Episode")
    plt.ylabel("Return")
    plt.grid(True)
    plt.legend()
    plt.savefig(png_path, dpi=150, bbox_inches="tight")
    plt.savefig(pdf_path, bbox_inches="tight")
    plt.close()
    return png_path, pdf_path

# ============================================================
# 9. UI / UI
# ============================================================
def build_ui_and_run():
    w_episodes = widgets.IntSlider(value=int(PARAM_INIT["EPISODES"]), min=50, max=2000, step=50, description="EPISODES")
    w_lr = widgets.FloatLogSlider(value=float(PARAM_INIT["LR"]), base=10, min=-5, max=-2, step=0.1, description="LR")
    w_gamma = widgets.FloatSlider(value=float(PARAM_INIT["GAMMA"]), min=0.90, max=0.999, step=0.001, description="GAMMA")
    w_batch = widgets.IntSlider(value=int(PARAM_INIT["BATCH_SIZE"]), min=32, max=512, step=32, description="BATCH")
    w_hidden = widgets.IntSlider(value=int(PARAM_INIT["HIDDEN_SIZE"]), min=64, max=512, step=64, description="HIDDEN")
    w_eps_start = widgets.FloatSlider(value=float(PARAM_INIT["EPS_START"]), min=0.1, max=1.0, step=0.05, description="EPS_START")
    w_eps_end = widgets.FloatSlider(value=float(PARAM_INIT["EPS_END"]), min=0.0, max=0.2, step=0.01, description="EPS_END")
    w_eps_decay = widgets.IntSlider(value=int(PARAM_INIT["EPS_DECAY_STEPS"]), min=5_000, max=200_000, step=5_000, description="EPS_DECAY")
    w_target = widgets.IntSlider(value=int(PARAM_INIT["TARGET_UPDATE"]), min=200, max=5000, step=200, description="TARGET_UPD")
    w_plot_every = widgets.IntSlider(value=int(PARAM_INIT["PLOT_EVERY"]), min=10, max=200, step=10, description="PLOT_EVERY")
    w_plot_mode = widgets.Dropdown(options=["return", "loss"], value=str(PARAM_INIT["PLOT_MODE"]), description="PLOT_MODE")
    w_seed = widgets.IntText(value=int(PARAM_INIT["SEED"]), description="SEED")

    btn_train = widgets.Button(description="Train", button_style="success")
    btn_eval = widgets.Button(description="Evaluate + Video", button_style="info")
    btn_reset = widgets.Button(description="Reset params", button_style="warning")

    out = widgets.Output()
    state = {"logs": None}

    def _apply_params():
        PARAM_INIT["EPISODES"] = int(w_episodes.value)
        PARAM_INIT["LR"] = float(w_lr.value)
        PARAM_INIT["GAMMA"] = float(w_gamma.value)
        PARAM_INIT["BATCH_SIZE"] = int(w_batch.value)
        PARAM_INIT["HIDDEN_SIZE"] = int(w_hidden.value)
        PARAM_INIT["EPS_START"] = float(w_eps_start.value)
        PARAM_INIT["EPS_END"] = float(w_eps_end.value)
        PARAM_INIT["EPS_DECAY_STEPS"] = int(w_eps_decay.value)
        PARAM_INIT["TARGET_UPDATE"] = int(w_target.value)
        PARAM_INIT["PLOT_EVERY"] = int(w_plot_every.value)
        PARAM_INIT["PLOT_MODE"] = str(w_plot_mode.value)
        PARAM_INIT["SEED"] = int(w_seed.value)
        PARAM_INIT["DEVICE"] = "cuda" if torch.cuda.is_available() else "cpu"

    def on_reset(_):
        w_episodes.value = 600
        w_lr.value = 1e-3
        w_gamma.value = 0.99
        w_batch.value = 128
        w_hidden.value = 256
        w_eps_start.value = 1.0
        w_eps_end.value = 0.05
        w_eps_decay.value = 50_000
        w_target.value = 1_000
        w_plot_every.value = 20
        w_plot_mode.value = "return"
        w_seed.value = 42

    def on_train(_):
        with out:
            clear_output(wait=True)
            _apply_params()
            print_env_info(PARAM_INIT)
            print_param_table(PARAM_INIT)
            print_theory_cartpole()

            try:
                logs = train_dqn(PARAM_INIT)
            except Exception as e:
                print("[ERROR] Training failed:", e)
                return

            state["logs"] = logs

            print("\n=== Training Done ===")
            returns = logs["episode_returns"]
            print(f"Final return (last ep): {returns[-1]:.2f}")
            print(f"Mean return (last 20 eps): {np.mean(returns[-20:]):.2f}")

            plot_training_curves(logs, window=20, plot_mode=str(PARAM_INIT["PLOT_MODE"]))

            if bool(PARAM_INIT["SAVE_PLOTS"]):
                png_path, pdf_path = save_plot_files(logs, window=20)
                print(f"Saved plot: {png_path}")
                print(f"Saved plot: {pdf_path}")

    def on_eval(_):
        with out:
            clear_output(wait=True)
            if state["logs"] is None:
                print("[ERROR] Train first.")
                return
            qnet = state["logs"]["qnet"]
            print("=== Evaluating (greedy) and recording video ===")
            try:
                metrics = evaluate_policy(PARAM_INIT, qnet, seed_offset=2000, record_video=True)
            except Exception as e:
                print("[ERROR] Evaluation failed:", e)
                return
            for k, v in metrics.items():
                print(f"{k}: {v}")
            print(f"Video dir: {PARAM_INIT['VIDEO_DIR']}")
            show_latest_video(str(PARAM_INIT["VIDEO_DIR"]))

    btn_reset.on_click(on_reset)
    btn_train.on_click(on_train)
    btn_eval.on_click(on_eval)

    ui_left = widgets.VBox([
        w_episodes, w_lr, w_gamma, w_batch, w_hidden,
        w_eps_start, w_eps_end, w_eps_decay, w_target,
        w_plot_every, w_plot_mode, w_seed
    ])
    ui_right = widgets.VBox([btn_train, btn_eval, btn_reset])

    display(widgets.HBox([ui_left, ui_right]))
    display(out)

# ============================================================
# 10. RUN / 実行
# ============================================================
build_ui_and_run()