# Program Name: rl_cartpole_dqn_colab_allinone_v2.py
# Creation Date: 20260125
# Purpose: Reproducible CartPole DQN (Gymnasium) in Google Colab with pinned deps, auto-restart, UI, logging, plots, and video.
# ============================================================
# 0. SETUP (Pinned install + auto runtime restart) / セットアップ(版固定+自動再起動)
# ============================================================
# 日本語: NumPy ABI不整合(dtype size changed)を避けるため、依存を版固定で入れ直し、必要なら再起動する。
# English: Pin dependencies to avoid NumPy binary-compat issues and restart runtime if needed.
import os, sys, subprocess, signal
SETUP_MARK = "/content/rl_setup_done_cartpole_v2.txt"
PINNED = {
"numpy": "1.26.4",
"matplotlib": "3.8.4",
"ipywidgets": "8.1.1",
"numba": "0.59.1",
"gymnasium": "0.29.1",
"imageio": "2.34.1",
"imageio-ffmpeg": "0.5.1",
}
def _run(cmd):
subprocess.check_call(cmd)
def _pip(args):
_run([sys.executable, "-m", "pip"] + args)
def setup_and_maybe_restart():
# Colab判定 / Detect Colab
in_colab = ("google.colab" in sys.modules) or ("COLAB_GPU" in os.environ) or (os.path.exists("/content"))
if not in_colab:
return # ローカルは再起動を強制しない / Do not force restart outside Colab
if os.path.exists(SETUP_MARK):
return
print("=== Installing pinned dependencies (first run) ===")
_pip(["install", "-q", "--upgrade", "pip"])
# 既存の不整合を消すために一度アンインストール / Uninstall potentially conflicting packages
_pip(["uninstall", "-y", "numpy", "gymnasium", "gymnasium-classic-control", "pyglet", "pygame"])
_pip(["install", "-q",
f"numpy=={PINNED['numpy']}",
f"matplotlib=={PINNED['matplotlib']}",
f"ipywidgets=={PINNED['ipywidgets']}",
f"numba=={PINNED['numba']}",
f"gymnasium=={PINNED['gymnasium']}",
f"gymnasium[classic-control]=={PINNED['gymnasium']}",
f"imageio=={PINNED['imageio']}",
f"imageio-ffmpeg=={PINNED['imageio-ffmpeg']}",
])
    # torchはColabに通常入っているため、無い場合だけ入れる / Torch is usually preinstalled in Colab; install it only if missing
try:
import torch # noqa
except Exception:
_pip(["install", "-q", "torch"])
with open(SETUP_MARK, "w") as f:
f.write("ok")
print("\n=== Runtime restart required ===")
print("Restarting now to finalize binary compatibility...")
os.kill(os.getpid(), signal.SIGKILL)
setup_and_maybe_restart()
# ============================================================
# 1. IMPORTS / インポート
# ============================================================
import random
import datetime
from dataclasses import dataclass
from typing import Dict, Tuple
import numpy as np
import matplotlib.pyplot as plt
import gymnasium as gym
import ipywidgets as widgets
from IPython.display import display, clear_output, HTML
from numba import njit
import torch
import torch.nn as nn
import torch.optim as optim
# Colab widget manager / Colabのウィジェット有効化
try:
from google.colab import output as colab_output
colab_output.enable_custom_widget_manager()
except Exception:
pass
# ============================================================
# 2. PARAM_INIT (Single source of truth) / 設定一元管理
# ============================================================
PARAM_INIT: Dict[str, object] = {
# --- Environment / 環境 ---
"ENV_ID": "CartPole-v1",
"MAX_EPISODE_STEPS": 500,
"SEED": 42,
# --- DQN / 学習設定 ---
"EPISODES": 600,
"GAMMA": 0.99,
"LR": 1e-3,
"BATCH_SIZE": 128,
"REPLAY_SIZE": 50_000,
"LEARNING_STARTS": 1_000,
"TRAIN_FREQ": 1,
"TARGET_UPDATE": 1_000,
"GRAD_CLIP_NORM": 10.0,
# --- Exploration / 探索 ---
"EPS_START": 1.0,
"EPS_END": 0.05,
"EPS_DECAY_STEPS": 50_000,
# --- Network / NN ---
"HIDDEN_SIZE": 256,
# --- Device / デバイス ---
"DEVICE": "cuda" if torch.cuda.is_available() else "cpu",
# --- Evaluation / 評価 ---
"EVAL_EPISODES": 20,
"SUCCESS_SCORE": 475.0,
"VIDEO_RECORD": True,
"VIDEO_DIR": "/content/videos_cartpole",
# --- Plot / 可視化 ---
"PLOT_EVERY": 20,
"PLOT_MODE": "return", # "return" or "loss"
"SAVE_PLOTS": True,
}
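# 日本語: UIを使わず設定だけ変えたい場合は、この辞書を直接書き換えてから学習関数を呼ぶ(セクション7参照)。
# English: Hedged sketch for running without the UI: edit this single source of truth in place before
# calling the training entry point defined in section 7, e.g. PARAM_INIT["EPISODES"] = 300 (example value).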
# ============================================================
# 3. UTILITIES / ユーティリティ
# ============================================================
def now_tag() -> str:
return datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
def set_global_seed(seed: int) -> None:
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def print_env_info(params: Dict[str, object]) -> None:
print("=== Runtime Info ===")
print(f"Time: {datetime.datetime.now().isoformat(timespec='seconds')}")
print(f"Env: {params['ENV_ID']}")
print(f"Device: {params['DEVICE']}")
print(f"python: {sys.version.split()[0]}")
print(f"torch: {torch.__version__}")
print(f"gymnasium: {gym.__version__}")
print(f"numpy: {np.__version__}")
print(f"matplotlib: {plt.matplotlib.__version__}")
def print_param_table(params: Dict[str, object]) -> None:
keys = [
"ENV_ID","SEED","EPISODES","GAMMA","LR","BATCH_SIZE","REPLAY_SIZE",
"LEARNING_STARTS","TRAIN_FREQ","TARGET_UPDATE","EPS_START","EPS_END",
"EPS_DECAY_STEPS","HIDDEN_SIZE","DEVICE","EVAL_EPISODES","SUCCESS_SCORE"
]
print("=== PARAM_INIT (selected) ===")
for k in keys:
print(f"{k:>16s}: {params[k]}")
@njit
def moving_average_numba(x: np.ndarray, window: int) -> np.ndarray:
n = x.size
if window <= 1:
return x.copy()
out = np.empty(n, dtype=np.float64)
s = 0.0
for i in range(n):
s += x[i]
if i >= window:
s -= x[i - window]
out[i] = s / window
else:
out[i] = s / (i + 1)
return out
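# Worked example (hedged sanity check): moving_average_numba(np.array([1., 2., 3., 4.]), 2)
# -> [1.0, 1.5, 2.5, 3.5]; before the window fills, the partial average over i+1 samples is used.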
def linear_epsilon(step: int, eps_start: float, eps_end: float, decay_steps: int) -> float:
if decay_steps <= 0:
return eps_end
t = min(max(step, 0), decay_steps)
return eps_start + (eps_end - eps_start) * (t / decay_steps)
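# Worked values (hedged) with the defaults EPS_START=1.0, EPS_END=0.05, EPS_DECAY_STEPS=50_000:
#   step=0       -> 1.0
#   step=25_000  -> 0.525  (halfway along the linear decay)
#   step>=50_000 -> 0.05   (clamped at EPS_END thereafter)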
def print_theory_cartpole() -> None:
print("=== 1) Dynamics (Cart-Pole) / 力学モデル(要点) ===")
print("State x (observation in CartPole-v1):")
print(" x = [cart_position, cart_velocity, pole_angle, pole_angular_velocity]")
print("Units:")
print(" cart_position [m], cart_velocity [m/s], pole_angle [rad], pole_angular_velocity [rad/s]")
print("Conceptual form:")
print(" x_dot = f(x, u), u is horizontal force (discrete left/right).")
print("Linearization idea (around upright):")
print(" For small |theta|, sin(theta)≈theta, cos(theta)≈1 -> x_dot ≈ A x + B u (LQR style).")
print("")
print("=== 2) MDP / MDP定式化 ===")
print("S: 4D observation, A: {0,1}, R: +1 per step, Terminal: angle/position limit or time limit, gamma: 0.99")
print("")
print("=== 3) DQN / DQN ===")
print("Bellman optimality:")
print(" Q*(s,a) = E[ r + γ * max_a' Q*(s',a') ]")
print("DQN uses replay buffer + target network to stabilize TD learning.")
# ============================================================
# 4. REPLAY BUFFER / 経験再生バッファ
# ============================================================
@dataclass
class ReplayBuffer:
capacity: int
obs_dim: int
def __post_init__(self):
self.obs = np.zeros((self.capacity, self.obs_dim), dtype=np.float32)
self.next_obs = np.zeros((self.capacity, self.obs_dim), dtype=np.float32)
self.acts = np.zeros((self.capacity,), dtype=np.int64)
self.rews = np.zeros((self.capacity,), dtype=np.float32)
self.dones = np.zeros((self.capacity,), dtype=np.float32)
self.ptr = 0
self.size = 0
def add(self, o: np.ndarray, a: int, r: float, no: np.ndarray, d: bool) -> None:
i = self.ptr
self.obs[i] = o
self.acts[i] = a
self.rews[i] = r
self.next_obs[i] = no
self.dones[i] = 1.0 if d else 0.0
self.ptr = (self.ptr + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)
def sample(self, batch_size: int, rng: np.random.Generator):
idx = rng.integers(0, self.size, size=batch_size)
return (self.obs[idx], self.acts[idx], self.rews[idx], self.next_obs[idx], self.dones[idx])
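# Minimal usage sketch (hedged; toy numbers, not the training defaults):
#   rb = ReplayBuffer(capacity=10_000, obs_dim=4)
#   rb.add(np.zeros(4, np.float32), 0, 1.0, np.zeros(4, np.float32), False)
#   obs_b, act_b, rew_b, next_b, done_b = rb.sample(32, np.random.default_rng(0))
#   # shapes: (32, 4), (32,), (32,), (32, 4), (32,); in training, sampling starts only after LEARNING_STARTS transitions.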
# ============================================================
# 5. Q-NETWORK / Qネットワーク
# ============================================================
class QNetwork(nn.Module):
def __init__(self, obs_dim: int, act_dim: int, hidden: int):
super().__init__()
self.net = nn.Sequential(
nn.Linear(obs_dim, hidden),
nn.ReLU(),
nn.Linear(hidden, hidden),
nn.ReLU(),
nn.Linear(hidden, act_dim),
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
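# Shape sketch (hedged): for CartPole-v1 obs_dim=4 and act_dim=2, so
#   QNetwork(4, 2, 256)(torch.zeros(32, 4))  # -> Q-values of shape (32, 2), one column per action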
# ============================================================
# 6. ENV + POLICY + UPDATE / 環境・方策・更新
# ============================================================
def make_env(env_id: str, seed: int, max_steps: int, record_video: bool=False, video_dir: str="/content/videos") -> gym.Env:
    # gym.make already applies the spec's TimeLimit (500 for CartPole-v1), so pass max_episode_steps
    # explicitly instead of nesting a second TimeLimit wrapper that would cap episodes at the inner 500.
    env = gym.make(env_id, max_episode_steps=max_steps, render_mode="rgb_array" if record_video else None)
env.reset(seed=seed)
env.action_space.seed(seed)
env.observation_space.seed(seed)
if record_video:
os.makedirs(video_dir, exist_ok=True)
env = gym.wrappers.RecordVideo(
env,
video_folder=video_dir,
episode_trigger=lambda ep: True,
name_prefix=f"{env_id}_{now_tag()}",
disable_logger=True
)
return env
def select_action(qnet: QNetwork, obs: np.ndarray, eps: float, act_dim: int, device: str) -> int:
if random.random() < eps:
return random.randrange(act_dim)
with torch.no_grad():
x = torch.tensor(obs, dtype=torch.float32, device=device).unsqueeze(0)
q = qnet(x)
return int(torch.argmax(q, dim=1).item())
def dqn_update(qnet: QNetwork, target: QNetwork, opt: optim.Optimizer, batch, gamma: float, device: str, grad_clip_norm: float) -> float:
obs, acts, rews, next_obs, dones = batch
obs_t = torch.tensor(obs, dtype=torch.float32, device=device)
acts_t = torch.tensor(acts, dtype=torch.int64, device=device).unsqueeze(1)
rews_t = torch.tensor(rews, dtype=torch.float32, device=device).unsqueeze(1)
next_obs_t = torch.tensor(next_obs, dtype=torch.float32, device=device)
dones_t = torch.tensor(dones, dtype=torch.float32, device=device).unsqueeze(1)
q_sa = qnet(obs_t).gather(1, acts_t)
with torch.no_grad():
max_next_q = target(next_obs_t).max(dim=1, keepdim=True).values
td_target = rews_t + gamma * max_next_q * (1.0 - dones_t)
loss = nn.functional.smooth_l1_loss(q_sa, td_target)
opt.zero_grad(set_to_none=True)
loss.backward()
if grad_clip_norm is not None and grad_clip_norm > 0:
nn.utils.clip_grad_norm_(qnet.parameters(), grad_clip_norm)
opt.step()
return float(loss.item())
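# Huber / smooth-L1 intuition (hedged; default beta=1.0 as used above): the per-element loss is
# 0.5*e^2 for |e| < 1 and |e| - 0.5 otherwise, so a TD error of 0.5 costs 0.125 while an error of
# 2.0 costs 1.5, which damps the influence of outlier targets compared to a squared-error loss.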
def evaluate_policy(params: Dict[str, object], qnet: QNetwork, seed_offset: int = 2000, record_video: bool = True) -> Dict[str, float]:
env = make_env(params["ENV_ID"], int(params["SEED"]) + seed_offset, int(params["MAX_EPISODE_STEPS"]),
record_video=record_video, video_dir=str(params["VIDEO_DIR"]))
device = str(params["DEVICE"])
act_dim = env.action_space.n
returns, lengths = [], []
for ep in range(int(params["EVAL_EPISODES"])):
obs, _ = env.reset(seed=int(params["SEED"]) + seed_offset + ep)
done = False
trunc = False
ep_ret = 0.0
ep_len = 0
while not (done or trunc):
a = select_action(qnet, obs, eps=0.0, act_dim=act_dim, device=device)
obs, r, done, trunc, _ = env.step(a)
ep_ret += float(r)
ep_len += 1
returns.append(ep_ret)
lengths.append(ep_len)
env.close()
return {
"eval_return_mean": float(np.mean(returns)),
"eval_return_std": float(np.std(returns)),
"eval_len_mean": float(np.mean(lengths)),
}
def show_latest_video(video_dir: str):
# 日本語: RecordVideoが作ったmp4を表示(最新)
# English: Display the newest mp4 in Colab output.
if not os.path.isdir(video_dir):
print("No video directory found.")
return
mp4s = []
for root, _, files in os.walk(video_dir):
for f in files:
if f.lower().endswith(".mp4"):
mp4s.append(os.path.join(root, f))
if not mp4s:
print("No mp4 found.")
return
mp4s.sort(key=lambda p: os.path.getmtime(p))
path = mp4s[-1]
print("Video:", path)
with open(path, "rb") as f:
data = f.read()
import base64
b64 = base64.b64encode(data).decode("ascii")
display(HTML(f"""
<video width="640" controls>
<source src="data:video/mp4;base64,{b64}" type="video/mp4">
</video>
"""))
# ============================================================
# 7. TRAIN LOOP / 学習ループ
# ============================================================
def train_dqn(params: Dict[str, object]) -> Dict[str, object]:
set_global_seed(int(params["SEED"]))
rng = np.random.default_rng(int(params["SEED"]))
env = make_env(str(params["ENV_ID"]), int(params["SEED"]), int(params["MAX_EPISODE_STEPS"]), record_video=False)
obs_dim = env.observation_space.shape[0]
act_dim = env.action_space.n
device = str(params["DEVICE"])
qnet = QNetwork(obs_dim, act_dim, int(params["HIDDEN_SIZE"])).to(device)
target = QNetwork(obs_dim, act_dim, int(params["HIDDEN_SIZE"])).to(device)
target.load_state_dict(qnet.state_dict())
target.eval()
opt = optim.Adam(qnet.parameters(), lr=float(params["LR"]))
rb = ReplayBuffer(int(params["REPLAY_SIZE"]), obs_dim)
episode_returns, episode_lengths, losses = [], [], []
steps_total = 0
for ep in range(int(params["EPISODES"])):
obs, _ = env.reset(seed=int(params["SEED"]) + ep)
done = False
trunc = False
ep_ret = 0.0
ep_len = 0
while not (done or trunc):
eps = linear_epsilon(
steps_total,
float(params["EPS_START"]),
float(params["EPS_END"]),
int(params["EPS_DECAY_STEPS"])
)
a = select_action(qnet, obs, eps, act_dim, device)
next_obs, r, done, trunc, _ = env.step(a)
            # Store only true termination as done so TD targets still bootstrap across
            # time-limit truncation (a truncated state is not terminal).
            rb.add(obs, a, float(r), next_obs, bool(done))
obs = next_obs
ep_ret += float(r)
ep_len += 1
steps_total += 1
if rb.size >= int(params["LEARNING_STARTS"]) and (steps_total % int(params["TRAIN_FREQ"]) == 0):
batch = rb.sample(int(params["BATCH_SIZE"]), rng)
loss_val = dqn_update(
qnet, target, opt, batch,
gamma=float(params["GAMMA"]),
device=device,
grad_clip_norm=float(params["GRAD_CLIP_NORM"])
)
losses.append(loss_val)
if steps_total % int(params["TARGET_UPDATE"]) == 0:
target.load_state_dict(qnet.state_dict())
episode_returns.append(ep_ret)
episode_lengths.append(ep_len)
if (ep + 1) % max(1, int(params["PLOT_EVERY"])) == 0:
recent = episode_returns[-int(params["PLOT_EVERY"]):]
print(f"[Train] ep={ep+1:4d}/{params['EPISODES']} "
f"return_mean_recent={np.mean(recent):7.2f} steps_total={steps_total:7d} eps={eps:5.3f}")
env.close()
return {
"qnet": qnet,
"target": target,
"episode_returns": np.array(episode_returns, dtype=np.float64),
"episode_lengths": np.array(episode_lengths, dtype=np.float64),
"losses": np.array(losses, dtype=np.float64),
"steps_total": steps_total,
"params": params.copy(),
}
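# Programmatic run sketch (hedged; the UI in section 9 performs the same sequence):
#   logs = train_dqn(PARAM_INIT)
#   plot_training_curves(logs, window=20, plot_mode="return")
#   metrics = evaluate_policy(PARAM_INIT, logs["qnet"], record_video=bool(PARAM_INIT["VIDEO_RECORD"]))
#   show_latest_video(str(PARAM_INIT["VIDEO_DIR"]))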
# ============================================================
# 8. PLOT / プロット
# ============================================================
def plot_training_curves(logs: Dict[str, object], window: int = 20, plot_mode: str = "return") -> None:
returns = logs["episode_returns"]
losses = logs["losses"]
plt.figure()
if plot_mode == "loss":
if losses.size == 0:
plt.title("Training Loss (no data)")
plt.xlabel("Update step")
plt.ylabel("Huber loss")
plt.grid(True)
plt.show()
return
plt.plot(losses)
plt.title("Training Loss (Huber)")
plt.xlabel("Update step")
plt.ylabel("Huber loss")
plt.grid(True)
else:
ma = moving_average_numba(returns.astype(np.float64), window)
plt.plot(returns, label="Return")
plt.plot(ma, label=f"MovingAvg(window={window})")
plt.title("Episode Return")
plt.xlabel("Episode")
plt.ylabel("Return")
plt.grid(True)
plt.legend()
plt.show()
def save_plot_files(logs: Dict[str, object], window: int = 20) -> Tuple[str, str]:
tag = now_tag()
png_path = f"/content/train_curve_{tag}.png"
pdf_path = f"/content/train_curve_{tag}.pdf"
returns = logs["episode_returns"]
ma = moving_average_numba(returns.astype(np.float64), window)
plt.figure()
plt.plot(returns, label="Return")
plt.plot(ma, label=f"MovingAvg(window={window})")
plt.title("Episode Return")
plt.xlabel("Episode")
plt.ylabel("Return")
plt.grid(True)
plt.legend()
plt.savefig(png_path, dpi=150, bbox_inches="tight")
plt.savefig(pdf_path, bbox_inches="tight")
plt.close()
return png_path, pdf_path
# ============================================================
# 9. UI / UI
# ============================================================
def build_ui_and_run():
w_episodes = widgets.IntSlider(value=int(PARAM_INIT["EPISODES"]), min=50, max=2000, step=50, description="EPISODES")
w_lr = widgets.FloatLogSlider(value=float(PARAM_INIT["LR"]), base=10, min=-5, max=-2, step=0.1, description="LR")
w_gamma = widgets.FloatSlider(value=float(PARAM_INIT["GAMMA"]), min=0.90, max=0.999, step=0.001, description="GAMMA")
w_batch = widgets.IntSlider(value=int(PARAM_INIT["BATCH_SIZE"]), min=32, max=512, step=32, description="BATCH")
w_hidden = widgets.IntSlider(value=int(PARAM_INIT["HIDDEN_SIZE"]), min=64, max=512, step=64, description="HIDDEN")
w_eps_start = widgets.FloatSlider(value=float(PARAM_INIT["EPS_START"]), min=0.1, max=1.0, step=0.05, description="EPS_START")
w_eps_end = widgets.FloatSlider(value=float(PARAM_INIT["EPS_END"]), min=0.0, max=0.2, step=0.01, description="EPS_END")
w_eps_decay = widgets.IntSlider(value=int(PARAM_INIT["EPS_DECAY_STEPS"]), min=5_000, max=200_000, step=5_000, description="EPS_DECAY")
w_target = widgets.IntSlider(value=int(PARAM_INIT["TARGET_UPDATE"]), min=200, max=5000, step=200, description="TARGET_UPD")
w_plot_every = widgets.IntSlider(value=int(PARAM_INIT["PLOT_EVERY"]), min=10, max=200, step=10, description="PLOT_EVERY")
w_plot_mode = widgets.Dropdown(options=["return", "loss"], value=str(PARAM_INIT["PLOT_MODE"]), description="PLOT_MODE")
w_seed = widgets.IntText(value=int(PARAM_INIT["SEED"]), description="SEED")
btn_train = widgets.Button(description="Train", button_style="success")
btn_eval = widgets.Button(description="Evaluate + Video", button_style="info")
btn_reset = widgets.Button(description="Reset params", button_style="warning")
out = widgets.Output()
state = {"logs": None}
def _apply_params():
PARAM_INIT["EPISODES"] = int(w_episodes.value)
PARAM_INIT["LR"] = float(w_lr.value)
PARAM_INIT["GAMMA"] = float(w_gamma.value)
PARAM_INIT["BATCH_SIZE"] = int(w_batch.value)
PARAM_INIT["HIDDEN_SIZE"] = int(w_hidden.value)
PARAM_INIT["EPS_START"] = float(w_eps_start.value)
PARAM_INIT["EPS_END"] = float(w_eps_end.value)
PARAM_INIT["EPS_DECAY_STEPS"] = int(w_eps_decay.value)
PARAM_INIT["TARGET_UPDATE"] = int(w_target.value)
PARAM_INIT["PLOT_EVERY"] = int(w_plot_every.value)
PARAM_INIT["PLOT_MODE"] = str(w_plot_mode.value)
PARAM_INIT["SEED"] = int(w_seed.value)
PARAM_INIT["DEVICE"] = "cuda" if torch.cuda.is_available() else "cpu"
def on_reset(_):
w_episodes.value = 600
w_lr.value = 1e-3
w_gamma.value = 0.99
w_batch.value = 128
w_hidden.value = 256
w_eps_start.value = 1.0
w_eps_end.value = 0.05
w_eps_decay.value = 50_000
w_target.value = 1_000
w_plot_every.value = 20
w_plot_mode.value = "return"
w_seed.value = 42
def on_train(_):
with out:
clear_output(wait=True)
_apply_params()
print_env_info(PARAM_INIT)
print_param_table(PARAM_INIT)
print_theory_cartpole()
try:
logs = train_dqn(PARAM_INIT)
except Exception as e:
print("[ERROR] Training failed:", e)
return
state["logs"] = logs
print("\n=== Training Done ===")
returns = logs["episode_returns"]
print(f"Final return (last ep): {returns[-1]:.2f}")
print(f"Mean return (last 20 eps): {np.mean(returns[-20:]):.2f}")
plot_training_curves(logs, window=20, plot_mode=str(PARAM_INIT["PLOT_MODE"]))
if bool(PARAM_INIT["SAVE_PLOTS"]):
png_path, pdf_path = save_plot_files(logs, window=20)
print(f"Saved plot: {png_path}")
print(f"Saved plot: {pdf_path}")
def on_eval(_):
with out:
clear_output(wait=True)
if state["logs"] is None:
print("[ERROR] Train first.")
return
qnet = state["logs"]["qnet"]
print("=== Evaluating (greedy) and recording video ===")
try:
                metrics = evaluate_policy(PARAM_INIT, qnet, seed_offset=2000, record_video=bool(PARAM_INIT["VIDEO_RECORD"]))
except Exception as e:
print("[ERROR] Evaluation failed:", e)
return
for k, v in metrics.items():
print(f"{k}: {v}")
print(f"Video dir: {PARAM_INIT['VIDEO_DIR']}")
show_latest_video(str(PARAM_INIT["VIDEO_DIR"]))
btn_reset.on_click(on_reset)
btn_train.on_click(on_train)
btn_eval.on_click(on_eval)
ui_left = widgets.VBox([
w_episodes, w_lr, w_gamma, w_batch, w_hidden,
w_eps_start, w_eps_end, w_eps_decay, w_target,
w_plot_every, w_plot_mode, w_seed
])
ui_right = widgets.VBox([btn_train, btn_eval, btn_reset])
display(widgets.HBox([ui_left, ui_right]))
display(out)
# ============================================================
# 10. RUN / 実行
# ============================================================
build_ui_and_run()