2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ディープラーニングによる二値分類タスク:閾値の操作とPyTorchの基本

Posted at

はじめに

二値分類問題は、データが二つのクラスのいずれかに属するかを判断するタスクです。メールがスパムか否か、画像に猫が含まれるかどうかなど、多くのアプリケーションで見られます。この記事では、PyTorchを用いた二値分類モデルの構築と、予測における閾値の適切な設定方法に焦点を当てます。

PyTorchとは?

PyTorchは、FacebookのAIリサーチチームによって開発されたオープンソースの機械学習ライブラリで、その柔軟性とパワフルな機能により、研究者や開発者から広く支持されています。PyTorchは、動的な計算グラフを提供し、複雑なアーキテクチャの構築とデバッグを容易にします。

データの準備と前処理

モデルの訓練に先立ち、データの準備と前処理が必要です。データをモデルに供給する前に、特徴量のスケーリングや正規化を行うことで、訓練プロセスを安定させ、収束を早めることができます。以下は、データの読み込み、標準化、および訓練セットと検証セットへの分割を行う基本的な手順です。

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

train_df = pd.read_csv('path/to/train_data.csv')
X = train_df.drop('target_column', axis=1).values
y = train_df['target_column'].values

# データの標準化
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X)

X_train, X_val, y_train, y_val = train_test_split(X_scaled, y, test_size=0.2, random_state=42)

モデルの構築

PyTorchを用いて、二値分類のためのシンプルなニューラルネットワークを構築します。ここでは、複数の全結合層を持つ基本的なアーキテクチャを使用します。

import torch
import torch.nn as nn
import torch.optim as optim

class BinaryClassifier(nn.Module):
    def __init__(self, num_features):
        super(BinaryClassifier, self).__init__()
        self.fc1 = nn.Linear(num_features, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))

model = BinaryClassifier(X_train.shape[1])

訓練プロセス

モデルの訓練には、適切な損失関数と最適化アルゴリズムの選択が重要です。二値分類では、二項交差エントロピー損失が一般的に使用されます。

criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

閾値の最適化

モデルが予測する値(通常は確率)をクラスラベルに変換するためには閾値が必要です。多くのケースで0.5が使用されますが、これは必ずしも最適な選択ではありません。特に、データセットが不均衡な場合や、偽陽性と偽陰性のコストが異なる場合には、閾値を調整することでモデルの性能を向上させることができます。

検証データセットを使用してモデルの性能を評価し、F1スコアなどの指標を用いて閾値を決定します。

threshold = 0.2  # 閾値を0.2に設定
y_val_pred = []
with torch.no_grad():
    for inputs, labels in val_loader:
        outputs = model(inputs).squeeze()
        predicted = (outputs >= threshold).float()
        y_val_pred.extend(predicted.numpy())

mean_f1 = f1_score(y_val, np.array(y_val_pred), average='macro')
print(f'Mean F1 Score: {mean_f1}')

テストデータでの予測とサブミッション

最後に、選択した閾値を使用してテストデータセットで予測を行い、結果をサブミッションファイルとして保存します。

predictions = []
with torch.no_grad():
    for inputs in test_loader:
        outputs = model(inputs).squeeze()
        predictions.extend((outputs >= threshold).float().numpy())

sample_submission[1] = np.where(np.array(predictions) >= threshold, 1, 0)
sample_submission.to_csv('/content/drive/MyDrive/fin_tech/submission_nn_.csv', header=None)
print('Submission file has been created.')

最後に

閾値の適切な設定は、二値分類タスクにおけるモデルの性能を大きく左右します。PyTorchを使用したモデルの構築と訓練プロセスを通じて、異なる閾値でのモデルの挙動を評価し、タスクの要件に最適な閾値を見つけることが重要です。

2
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?