2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

クロスバリデーションとアーリーストッピング

Posted at

クロスバリデーションの統計学的根拠

主にhold-out法と呼ばれるバリデーションに対して、統計学的根拠が知りたくなり調べてみました。

irisデータセットを用いた統計的解析

import numpy as np
import pandas as pd
from sklearn.model_selection import KFold, cross_val_score
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from scipy.stats import t
import matplotlib.pyplot as plt

# データセットの読み込み (Loading the dataset)
data = load_iris()
X, y = data.data, data.target

# 分類器の設定 (Setting up the classifier)
model = RandomForestClassifier(random_state=42)

# k分割クロスバリデーション (k-fold cross-validation)
kf = KFold(n_splits=5, shuffle=True, random_state=42)
scores = cross_val_score(model, X, y, cv=kf)

# 統計的解析 (Statistical analysis)
mean_score = np.mean(scores)  # Mean score
std_dev_score = np.std(scores, ddof=1)  # Sample standard deviation
n = len(scores)  # Sample size
confidence_level = 0.95  # Confidence level
alpha = 1 - confidence_level
t_critical = t.ppf(1 - alpha/2, df=n-1)  # t critical value

# 信頼区間の計算 (Calculation of the confidence interval)
margin_of_error = t_critical * (std_dev_score / np.sqrt(n))
confidence_interval = (mean_score - margin_of_error, mean_score + margin_of_error)

# 結果の表示 (Displaying the results)
print("Scores for each fold:", scores)
print("Mean score:", mean_score)
print("Sample standard deviation:", std_dev_score)
print(f"{confidence_level*100:.1f}% Confidence interval:", confidence_interval)

# 結果を可視化 (Visualization of the results)
plt.figure(figsize=(8, 5))
plt.bar(range(1, len(scores) + 1), scores, color='skyblue', edgecolor='black', alpha=0.8)
plt.axhline(mean_score, color='red', linestyle='--', label=f'Mean score: {mean_score:.2f}')
plt.fill_between(
    range(0, len(scores) + 2),
    confidence_interval[0],
    confidence_interval[1],
    color='red',
    alpha=0.2,
    label=f'{confidence_level*100:.1f}% Confidence interval'
)
plt.xticks(range(1, len(scores) + 1), [f'Fold {i}' for i in range(1, len(scores) + 1)])
plt.ylim(0, 1)
plt.xlabel('Fold number')
plt.ylabel('Score')
plt.title('k-Fold Cross-Validation Scores and Confidence Interval')
plt.legend()
plt.grid(alpha=0.5)
plt.show()

image.png

  • 標本分布:クロスバリデーションで得られたスコアは、統計学的に母集団スコアの標本とみなせる。
  • 信頼区間:クロスバリデーションスコアの信頼区間を計算することで、モデル性能の推定精度を測定。
  • t分布の利用:スコア数が少ない場合、t分布が有効。

early stopping

モデルの過学習を防ぎ、汎化性能を向上させるために必要。
以下は、RandomForestClassifierを使用してアーリーストッピングを実現するコード。RandomForestClassifier自体にはアーリーストッピング機能が直接組み込まれていないが、カスタム実装で再現が可能。例えば、ランダムフォレストのツリー数を増やしながら、検証スコアが一定のパット数(ラウンド)改善しない場合に停止する仕組みを実現する。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# データセットの生成 (Generate a synthetic dataset)
X, y = make_classification(n_samples=1000, n_features=20, n_classes=2, random_state=42)
X_train, X_val, y_train, y_val = train_test_split(X, y, test_size=0.2, random_state=42)

# パラメータ設定 (Set parameters)
max_trees = 500  # 最大のツリー数 (Maximum number of trees)
step_size = 10   # ツリー数を増やすステップ (Step size for trees)
patience = 10    # パット数 (Number of patience rounds)
best_score = 0   # 最良スコアの初期値 (Best score)
no_improve_rounds = 0  # 改善がないラウンド数 (Number of rounds without improvement)
early_stop = False  # アーリーストッピングフラグ (Early stopping flag)

train_scores = []
val_scores = []
n_trees_history = []

# アーリーストッピングの実装 (Early Stopping Implementation)
for n_trees in range(step_size, max_trees + 1, step_size):
    model = RandomForestClassifier(n_estimators=n_trees, random_state=42)
    model.fit(X_train, y_train)
    
    # 訓練スコアと検証スコアの計算 (Calculate training and validation scores)
    train_score = accuracy_score(y_train, model.predict(X_train))
    val_score = accuracy_score(y_val, model.predict(X_val))
    
    train_scores.append(train_score)
    val_scores.append(val_score)
    n_trees_history.append(n_trees)
    
    # スコアの改善をチェック (Check for improvement)
    if val_score > best_score:
        best_score = val_score
        no_improve_rounds = 0  # 改善があればリセット (Reset the counter)
    else:
        no_improve_rounds += 1
    
    if no_improve_rounds >= patience:
        early_stop = True
        print(f"Early stopping at {n_trees} trees with best validation score: {best_score:.4f}")
        break

# 訓練履歴の可視化 (Visualize training history)
plt.figure(figsize=(10, 6))
plt.plot(n_trees_history, train_scores, label="Training Accuracy", color="blue")
plt.plot(n_trees_history, val_scores, label="Validation Accuracy", color="orange")
if early_stop:
    plt.axvline(n_trees, color="red", linestyle="--", label="Early Stopping Point")
plt.xlabel("Number of Trees")
plt.ylabel("Accuracy")
plt.title("Training and Validation Accuracy with Early Stopping")
plt.legend()
plt.grid(alpha=0.5)
plt.show()

# 最終スコアの表示 (Display final scores)
print(f"Best Validation Accuracy: {best_score:.4f}")

image.png

oblivious decision tree

各深さの分岐の条件式がすべて同じ決定木。単純で解釈しやすい構造を持つ特殊な決定木であり、各レベルで同じ特徴量に基づいて分岐する。この特性により、ODTはビジネスコンテキストでも有効。

image.png

以下のような簡易的なサンプルデータセットに適用してみる。

np.random.seed(42)
customer_data = pd.DataFrame({
    "Customer_Age": np.random.randint(18, 70, 100),
    "Customer_Segment": np.random.choice(["High Value", "Low Value"], 100),
    "Support_Team": np.random.choice(["Technical", "Billing", "General"], 100)
})

# Encode categorical columns
customer_data["Customer_Segment"] = customer_data["Customer_Segment"].map({"High Value": 1, "Low Value": 0})
customer_data["Support_Team"] = customer_data["Support_Team"].map({"Technical": 0, "Billing": 1, "General": 2})

image.png

2
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?