3
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ロジスティック回帰のすゝめ

Last updated at Posted at 2024-01-21

はじめに

こんにちは!
僕は研究を行いながら、長期インターンでデータサイエンティストとして働く大学院生です!

学部時代から長期インターンを始め、現在4社経験してきました。
この経験から、プログラミングの学習を始めたばかりの人や、長期インターンを行う勇気が出ない人に、学習サポートやデータ分析の実績作り支援などを行わせてもらっています!

僕自身、プログラミングの習得や長期インターン探しに苦労したので、その経験をお伝えすることで、より多くの人が挫折せずデータサイエンティストになるまで成長して欲しいです!

以下でサポートを行なっているのでご興味ある方はご連絡ください!学生・社会人問わず専攻も問わずサポートいたします!

X(Twitter)

今回は、データサイエンスを学ぶ上で欠かせない分類モデルロジスティック回帰手法について解説します👍

ロジスティック回帰とは

ロジスティック回帰は、特定のカテゴリ(例えば、0か1)に属する確率を予測するための統計モデルです。このモデルは、特にバイナリ分類問題に適しています。

数学的定式化

ロジスティック回帰は、以下の式で表されます。

P(Y=1|X)=\frac{1}{1+e^{-βX}}
P(Y=1|X)は、特徴量Xが与えられた際のYが1になる確率
βはモデルのパラメータ

シグモイド関数

シグモイド関数は、ロジスティック回帰における中核的な要素です。この関数は、実数値を0と1の間の値にマッピングするために使用され、結果を確率として解釈できるようにします。

シグモイド関数の定義

シグモイド関数は次の式で定義されます。

f(x)=\frac{1}{1+e^{-x}}

ここで、$x$は入力(例えば、線形回帰モデルの出力)です。

オッズとロジット変換

シグモイド関数を理解するためには、オッズとロジット変換の概念が重要です。

  • オッズ: ある事象が起こる確率と起こらない確率の比です。事象が起こる確率を$p$とした場合以下で表される
\frac{p}{1-p}
  • ロジット変換: オッズの自然対数を取ることです。ロジット関数は確率$p$をロジットスケールに変換したもので以下で表される
\log\Big(\frac{p}{1-p}\Big)

シグモイド関数は、実質的にロジット変換の逆関数です。ロジット関数はオッズを対数スケールに変換しますが、シグモイド関数はその逆を行い、ロジットスケールの値を確率に変換します。

シグモイド関数の役割

シグモイド関数の主な役割は、線形モデルの出力(ロジット)を0と1の間の確率に変換することです。これにより、ロジスティック回帰はバイナリ分類問題に対する確率的アプローチを提供します。

誤差関数

ロジスティック回帰モデルの学習において、誤差関数(または損失関数)は非常に重要です。この関数は、モデルの予測と実際のデータとの差異を評価します。

交差エントロピー誤差

ロジスティック回帰では、一般的に交差エントロピー損失が使用されます。この損失関数は、モデルの予測が実際のラベルからどれだけ離れているかを測定します。数式は以下のように表されます。

L(β)=-\sum_{i=1}^{n}\Big[y_i\log(p_i)+(1-y_i)\log(1-p_i)\Big]

ここで

  • $n$はデータ数
  • $y_i$は$i$番目のデータポイントの実際のラベル(0または1)
  • $p_i$はモデルによって予測された$i$番目のデータポイントの確率
  • $L(\beta)$はモデルのパラメータ$\beta$に対する全データポイントの合計損失

損失関数の解釈

  • $y_i=1$の場合: 式は$-\log(p_i)$となり、予測確率$p_i$が1に近づくにつれて損失は減少する
  • $y_i=0$の場合: 式は$-\log(1-p_i)$となり、予測確率$p_i$が$0$に近づくにつれて損失は減少する

つまり、モデルの予測が正確であればあるほど、損失は小さくなります。逆に、予測が実際のラベルと大きく異なる場合、損失は大きくなります。

最適化アルゴリズム

ロジスティック回帰モデルの最適化には、勾配降下法が一般的に使用されます。この方法は、モデルの誤差を最小限にするようにパラメータを逐次的に更新するものです。

勾配降下法の基本

勾配降下法では、誤差関数$L(\beta)$の勾配(偏微分係数)を計算し、その勾配に沿ってパラメータを更新します。パラメータの更新式は以下のようになります。

\displaylines{
β_j=β_j-α\frac{{\partial}L}{{\partial}{β_j}} \\
β_j=(更新されるパラメータ) \\
α=(学習率) \\
\frac{{\partial}L}{{\partial}{β_j}}=(β_jに関する誤差関数の偏微分)
}

勾配の計算

ロジスティック回帰の誤差関数の偏微分は次のように計算されます。

\frac{{\partial}L}{{\partial}{β_j}}=\sum_{i=1}^n(p_i-y_i)x_{ij}

ここで、$p_i$はモデルによって予測された確率、$y_i$は$i$番目のデータポイントの実際のラベル、$x_{ij}$は$i$番目のデータポイントの$j$番目の特徴量です。

注意点

ロジスティック回帰を使用する際には、特に以下の2つの点に注意する必要があります:偏回帰係数の可視化と多重共線性。

偏回帰係数の可視化

  • 偏回帰係数は、各特徴量が目的変数に与える影響の大きさを示します
  • 係数の大きさと符号は、その特徴量が結果に与える影響の方向と強さを示します
  • 係数をバーチャート等で可視化することで、どの特徴量が予測に最も寄与しているかを直感的に理解できます
  • 係数の信頼区間を表示することで、その推定値の不確実性を把握することも重要です
  • モデルの解釈性を高め、よりデータドリブンな意思決定をサポートします
  • ビジネスや研究のコンテキストにおいて、どの変数が重要かを明確にすることができます

多重共線性

  • 多重共線性は、モデルの説明変数が高度に相関している場合に生じる問題です
  • これにより、個々の変数の影響を正確に推定することが困難になります
  • 多重共線性があると、偏回帰係数の推定が不安定になり、過学習を引き起こす可能性があります
  • 係数の解釈が難しくなり、モデルの信頼性が低下します。
  • 相関の高い変数の削除や統合で対処
  • 主成分分析(PCA)などの次元削減技術を使用して変数の線形独立性を確保できる
  • 正則化手法(例:リッジ回帰、ラッソ回帰)を使用して多重共線性の影響を緩和することも可能

実装例

ここでは、Pythonのscikit-learnライブラリを用いたロジスティック回帰の実装例を示します。データセットはとても有名なアヤメのデータセットを用います。

import matplotlib.pyplot as plt
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, accuracy_score
from sklearn.datasets import load_iris

# データセットの読み込み
data = load_iris()
X = data.data
y = (data.target == 2).astype(int)  # 例として2値分類を行う

# データの分割
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

# モデルの作成と学習
model = LogisticRegression()
model.fit(X_train, y_train)

# テストデータに対する予測
predictions = model.predict(X_test)

# 予測精度の評価
accuracy = accuracy_score(y_test, predictions)
print(f"Accuracy: {accuracy:.2f}")

# 詳細な分類レポート
print(classification_report(y_test, predictions))

# 係数の可視化
coefficients = model.coef_[0]
features = data.feature_names
plt.bar(features, coefficients)
plt.xticks(rotation=45)
plt.xlabel('Features')
plt.ylabel('Coefficients')
plt.title('Visualization of Coefficients')
plt.show()
  1. モデルを訓練データで学習させる
  2. テストデータでモデルの予測を行う
  3. accuracy_scoreを使用して予測の正確性を測定する
  4. classification_reportで精度、リコール、F1スコアなどの詳細な指標を表示する
  5. 偏回帰係数をバーチャートで可視化

これにより、モデルの予測精度を正確に評価し、モデルの解釈を容易にすることが可能です。

さいごに

最後まで読んでいただきありがとうございました!
少しでもデータサイエンティストを目指す方の一助となればと思います。

もし僕の活動にもご興味を持っていただけたら、X(Twitter)もフォローしていただけると嬉しいです!

X(Twitter)

参考文献

3
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?