1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

多項ロジスティック回帰の解釈について私的にまとめてみた

Posted at

はじめに

これは「imtakalab Advent Calendar 2023」の11日目の記事です。
https://adventar.org/calendars/9285

みなさん、こんにちは〜:frog:
日々の生活の物事の解釈に常々困っているそこらへんの大学4年生です。XやGitHubなどありますのでぜひフォローしていただけると嬉しいです:thumbsup:

ロジスティック回帰について

みなさん、多項ロジスティック回帰やっていますでしょうか?僕は最近になってめちゃくちゃやっています。多項ロジスティック回帰は以下のようなイメージが多いと思います。

  • クラスラベルを目的変数として、確率が高いものに割り振ることができるもの

ここで、2値であるときは所謂ロジスティック回帰になるわけですね。
式で表現すると、ロジスティック回帰は
$$ \pi(x) = \frac{\exp{(\beta_0 + \beta_1 x)}}{1 + \exp{(\beta_0 + \beta_1 x)}} $$
$E(Y|X) = \beta_0 + \beta_1 x$ としたとき、上記の $\pi(x)$ で表現されます。これをロジット変換と言います。1
$y$ で表現するなら、
$$ y = \pi(x) + \varepsilon \quad (\because \varepsilon は誤差項)$$
となります。

$\beta_0, \beta_1$ の推定値である$\widehat{\beta_0}, \widehat{\beta_1}$ は最尤法を用いて値を求めることができます。

また、説明変数が増えた場合は以下のような式になる。

$$ g(\boldsymbol{x}) = \ln\left( \frac{\pi(\boldsymbol{x})}{1-\pi(\boldsymbol{x})}\right) = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots + \beta_p x_p$$

上記の定義に基づくと、$\pi(\boldsymbol{x})$ は以下のように定義できる。

$$ \pi (\boldsymbol{x}) = \frac{\exp{\left(g(\boldsymbol{x})\right)}}{1 + \exp\left(g(\boldsymbol{x})\right)}$$

多項ロジスティック回帰

多項ロジスティック回帰では、$\pi(\boldsymbol{x})$ がsoftmax関数となります。クラスが3であるときは、以下の式となります。2

$$ \pi_{3}(\boldsymbol{x}) = \frac{\exp(\boldsymbol{x_3})}{\exp(\boldsymbol{x_1}) + \exp(\boldsymbol{x_2}) + \exp(\boldsymbol{x_3})}$$
同様に、最尤法で求められます。(今までの説明で間違っていたら、ご指摘していただけると幸いです。)

解釈方法

では、本題である解釈の仕方について解説します:pencil:
今回使うデータセットはデータサイエンスで有名なwineのデータ3を使用します。一応出力も一部載せておきますが、全部見たい方はこちらから参照してください :computer:

from sklearn.datasets import load_wine
wine = load_wine()
df = pd.DataFrame(wine.data, columns=wine.feature_names)
df['type'] = wine.target_names[wine.target]

データの中身をざっと見ていきましょう〜 :chart_with_upwards_trend:

df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 178 entries, 0 to 177
Data columns (total 14 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   alcohol                       178 non-null    float64
 1   malic_acid                    178 non-null    float64
 2   ash                           178 non-null    float64
 3   alcalinity_of_ash             178 non-null    float64
 4   magnesium                     178 non-null    float64
 5   total_phenols                 178 non-null    float64
 6   flavanoids                    178 non-null    float64
 7   nonflavanoid_phenols          178 non-null    float64
 8   proanthocyanins               178 non-null    float64
 9   color_intensity               178 non-null    float64
 10  hue                           178 non-null    float64
 11  od280/od315_of_diluted_wines  178 non-null    float64
 12  proline                       178 non-null    float64
 13  type                          178 non-null    object 
dtypes: float64(13), object(1)
memory usage: 19.6+ KB
df.head()
alcohol malic_acid ash alcalinity_of_ash magnesium total_phenols flavanoids nonflavanoid_phenols proanthocyanins color_intensity hue od280/od315_of_diluted_wines proline type
0 14.23 1.71 2.43 15.6 127.0 2.80 3.06 0.28 2.29 5.64 1.04 3.92 1065.0 class_0
1 13.20 1.78 2.14 11.2 100.0 2.65 2.76 0.26 1.28 4.38 1.05 3.40 1050.0 class_0
2 13.16 2.36 2.67 18.6 101.0 2.80 3.24 0.30 2.81 5.68 1.03 3.17 1185.0 class_0
3 14.37 1.95 2.50 16.8 113.0 3.85 3.49 0.24 2.18 7.80 0.86 3.45 1480.0 class_0
4 13.24 2.59 2.87 21.0 118.0 2.80 2.69 0.39 1.82 4.32 1.04 2.93 735.0 class_0

type変数はすべて、class_0, class_1, class_2という3つのカテゴリで分かれています。

多重共線性を判断するVIFも見ていきましょう。以下がその数値です。10以上で多重共線性と一般的に言われています。

feature VIF
0 alcohol 206.189057
1 malic_acid 8.925541
2 ash 165.640370
3 alcalinity_of_ash 73.141564
4 magnesium 67.364868
5 total_phenols 62.786935
6 flavanoids 35.535602
7 nonflavanoid_phenols 16.636708
8 proanthocyanins 17.115665
9 color_intensity 17.022272
10 hue 45.398407
11 od280/od315_of_diluted_wines 54.539165
12 proline 16.370828

malic_acid 以外はかなり高い数値となっているため、今回の多項ロジスティックの解釈はあまり信頼性が無いかもです:head_bandage:
178行14列のデータフレームですので、もう少し大きなデータですと容易に行くかもしれません:eyes:

/usr/local/lib/python3.10/dist-packages/statsmodels/discrete/discrete_model.py:3025: RuntimeWarning: overflow encountered in exp
  eXB = np.column_stack((np.ones(len(X)), np.exp(X)))
/usr/local/lib/python3.10/dist-packages/statsmodels/discrete/discrete_model.py:3026: RuntimeWarning: invalid value encountered in divide
  return eXB/eXB.sum(1)[:,None]
Optimization terminated successfully.
         Current function value: nan
         Iterations 24
MNLogit Regression Results
Dep. Variable:	type	No. Observations:	178
Model:	MNLogit	Df Residuals:	150
Method:	MLE	Df Model:	26
Date:	Sun, 10 Dec 2023	Pseudo R-squ.:	nan
Time:	07:21:18	Log-Likelihood:	nan
converged:	True	LL-Null:	-193.31
Covariance Type:	nonrobust	LLR p-value:	nan
type=class_1	coef	std err	z	P>|z|	[0.025	0.975]
const	nan	nan	nan	nan	nan	nan
alcohol	nan	nan	nan	nan	nan	nan
malic_acid	nan	nan	nan	nan	nan	nan
ash	nan	nan	nan	nan	nan	nan
alcalinity_of_ash	nan	nan	nan	nan	nan	nan
magnesium	nan	nan	nan	nan	nan	nan
total_phenols	nan	nan	nan	nan	nan	nan
flavanoids	nan	nan	nan	nan	nan	nan
nonflavanoid_phenols	nan	nan	nan	nan	nan	nan
proanthocyanins	nan	nan	nan	nan	nan	nan
color_intensity	nan	nan	nan	nan	nan	nan
hue	nan	nan	nan	nan	nan	nan
od280/od315_of_diluted_wines	nan	nan	nan	nan	nan	nan
proline	nan	nan	nan	nan	nan	nan
type=class_2	coef	std err	z	P>|z|	[0.025	0.975]
const	nan	nan	nan	nan	nan	nan
alcohol	nan	nan	nan	nan	nan	nan
malic_acid	nan	nan	nan	nan	nan	nan
ash	nan	nan	nan	nan	nan	nan
alcalinity_of_ash	nan	nan	nan	nan	nan	nan
magnesium	nan	nan	nan	nan	nan	nan
total_phenols	nan	nan	nan	nan	nan	nan
flavanoids	nan	nan	nan	nan	nan	nan
nonflavanoid_phenols	nan	nan	nan	nan	nan	nan
proanthocyanins	nan	nan	nan	nan	nan	nan
color_intensity	nan	nan	nan	nan	nan	nan
hue	nan	nan	nan	nan	nan	nan
od280/od315_of_diluted_wines	nan	nan	nan	nan	nan	nan
proline	nan	nan	nan	nan	nan	nan

残念ながら、すべてnanになってしまったので、代わりにirisデータ4を用いましょう…

irisも同様に、3クラスのラベルがあります。
そのクラスラベルを目的変数として、多項ロジスティック回帰をしてみましょう〜:chart_with_upwards_trend:

Warning: Maximum number of iterations has been exceeded.
         Current function value: 0.039662
         Iterations: 35
/usr/local/lib/python3.10/dist-packages/statsmodels/base/model.py:607: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals
  warnings.warn("Maximum Likelihood optimization failed to "
MNLogit Regression Results
Dep. Variable:	species	No. Observations:	150
Model:	MNLogit	Df Residuals:	140
Method:	MLE	Df Model:	8
Date:	Sun, 10 Dec 2023	Pseudo R-squ.:	0.9639
Time:	07:27:25	Log-Likelihood:	-5.9493
converged:	False	LL-Null:	-164.79
Covariance Type:	nonrobust	LLR p-value:	7.055e-64
species=versicolor	coef	std err	z	P>|z|	[0.025	0.975]
const	32.6240	3.87e+05	8.42e-05	1.000	-7.59e+05	7.59e+05
sepal length (cm)	-2.6491	3.42e+05	-7.74e-06	1.000	-6.71e+05	6.71e+05
sepal width (cm)	-2.1068	1.83e+05	-1.15e-05	1.000	-3.6e+05	3.6e+05
petal length (cm)	28.9332	5.68e+05	5.09e-05	1.000	-1.11e+06	1.11e+06
petal width (cm)	25.0304	4.11e+05	6.09e-05	1.000	-8.05e+05	8.05e+05
species=virginica	coef	std err	z	P>|z|	[0.025	0.975]
const	12.5222	3.87e+05	3.23e-05	1.000	-7.59e+05	7.59e+05
sepal length (cm)	-4.6837	3.42e+05	-1.37e-05	1.000	-6.71e+05	6.71e+05
sepal width (cm)	-5.0090	1.83e+05	-2.73e-05	1.000	-3.6e+05	3.6e+05
petal length (cm)	45.5233	5.68e+05	8.01e-05	1.000	-1.11e+06	1.11e+06
petal width (cm)	38.9222	4.11e+05	9.47e-05	1.000	-8.05e+05	8.05e+05

3クラスなのに、なんで2種類しかないの?という疑問に関しては、species=setosaを基準としているためです。ですので、上記の出力結果はクラスラベル数-1の数が出力されるわけです。
上記の理由により、解釈の仕方も「setosaの特定の説明変数よりも」大きいor小さいということになります。

例でいうと、setosaよりもversicolor, virginicaでのpetal length(cm)petal wigth(cm)は大きいということが分かります。注意点として、多項ロジスティック回帰(ロジスティック回帰も同様)での偏回帰係数の推定値は、オッズ比となりますので、そのままのスケール(単位)での解釈ができない点に注意です。

また、今回の検定の数値は標準誤差(std err)がかなり数値がデカくなってしまっているため、信頼性の薄い値となっています。解釈するときは注意しましょう。

おわりに

いかがでしたでしょうか? :shipit:
今回はstatsmodelのほうで統計量を出力しました。一応scikit-learnのほうでもできますが、アプローチがそもそも違うので使い方には注意が必要です。
多項ロジスティック回帰の解釈はPythonよりもR言語のほうが、インターネットの記事が多いですのでそちらも参照していただけると幸いです:pray:

ではでは〜 👋

  1. CiNii 図書 - データ解析のためのロジスティック回帰モデル

  2. ロジスティック回帰を多クラス分類に応用する【機械学習入門17】

  3. sklearn.datasets.load_wine — scikit-learn 1.3.2 documentation

  4. sklearn.datasets.load_iris — scikit-learn 1.3.2 documentation

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?