This is the Day 11 article of the "imtakalab Advent Calendar 2023".




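  • A model that takes class labels as the target variable and assigns each observation to the class with the highest predicted probability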

$$ \pi(x) = \frac{\exp{(\beta_0 + \beta_1 x)}}{1 + \exp{(\beta_0 + \beta_1 x)}} $$
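Here $\pi(x) = E(Y|x)$ is the conditional mean of $Y$ given $x$, and $\beta_0 + \beta_1 x$ is the linear predictor. Solving for the linear predictor gives $\ln\left(\frac{\pi(x)}{1-\pi(x)}\right) = \beta_0 + \beta_1 x$, which is called the logit transformation. [1]
Expressed in terms of $y$:
$$ y = \pi(x) + \varepsilon \quad (\varepsilon \text{ is the error term})$$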

The estimates $\hat{\beta}_0, \hat{\beta}_1$ of $\beta_0, \beta_1$ can be obtained by maximum likelihood.
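To make the maximum likelihood step a little more concrete, here is a minimal sketch that fits the single-variable model with Newton-Raphson using only the standard library. The toy data `xs`, `ys` and the helper name `fit_logistic` are made up for illustration; this is not the procedure any particular library uses internally.

```python
import math

def fit_logistic(xs, ys, n_iter=25):
    """Newton-Raphson maximum likelihood for pi(x) = 1 / (1 + exp(-(b0 + b1*x)))."""
    b0, b1 = 0.0, 0.0
    for _ in range(n_iter):
        # Accumulate the gradient of the log-likelihood (g0, g1)
        # and the entries of the 2x2 matrix X^T W X (h00, h01, h11).
        g0 = g1 = h00 = h01 = h11 = 0.0
        for x, y in zip(xs, ys):
            p = 1.0 / (1.0 + math.exp(-(b0 + b1 * x)))
            w = p * (1.0 - p)
            g0 += y - p
            g1 += (y - p) * x
            h00 += w
            h01 += w * x
            h11 += w * x * x
        det = h00 * h11 - h01 * h01
        # Newton step: beta <- beta + (X^T W X)^{-1} * gradient
        b0 += (h11 * g0 - h01 * g1) / det
        b1 += (h00 * g1 - h01 * g0) / det
    return b0, b1

# Hypothetical toy data: the two classes overlap, so a finite MLE exists.
xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0, 4.5, 5.0]
ys = [0, 0, 0, 1, 0, 1, 0, 1, 1, 1]
b0, b1 = fit_logistic(xs, ys)
print(b0, b1)
```

At the maximum, the score equations $\sum_i (y_i - \pi(x_i)) = 0$ and $\sum_i (y_i - \pi(x_i)) x_i = 0$ hold, which is a handy sanity check on the fitted values.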


$$ g(\boldsymbol{x}) = \ln\left( \frac{\pi(\boldsymbol{x})}{1-\pi(\boldsymbol{x})}\right) = \beta_0 + \beta_1 x_1 + \beta_2 x_2 + \dots + \beta_p x_p$$

Based on this definition, $\pi(\boldsymbol{x})$ can be written as follows.

$$ \pi (\boldsymbol{x}) = \frac{\exp{\left(g(\boldsymbol{x})\right)}}{1 + \exp\left(g(\boldsymbol{x})\right)}$$
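As a quick check that the two formulas are inverses of each other, here is a small sketch. The coefficient values and the input `x` are hypothetical and chosen only for illustration.

```python
import math

def logit(p):
    """g = ln(p / (1 - p)): maps a probability to the linear-predictor scale."""
    return math.log(p / (1.0 - p))

def inv_logit(g):
    """pi = exp(g) / (1 + exp(g)): maps the linear predictor back to a probability."""
    return math.exp(g) / (1.0 + math.exp(g))

# Hypothetical coefficients and input, for illustration only
beta0, beta1 = -1.0, 0.5
x = 3.0

g = beta0 + beta1 * x   # linear predictor
pi = inv_logit(g)       # probability in (0, 1)
print(g, pi, logit(pi)) # logit(pi) recovers g up to floating-point error
```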


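In multinomial logistic regression, $\pi(\boldsymbol{x})$ becomes the softmax function. With three classes, and writing $g_k(\boldsymbol{x})$ for the linear predictor of class $k$, the probability of class 3 is given by the following. [2]

$$ \pi_{3}(\boldsymbol{x}) = \frac{\exp\left(g_3(\boldsymbol{x})\right)}{\exp\left(g_1(\boldsymbol{x})\right) + \exp\left(g_2(\boldsymbol{x})\right) + \exp\left(g_3(\boldsymbol{x})\right)}$$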
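The softmax formula can be sketched in a few lines of Python. The three scores passed in are hypothetical per-class linear predictors, and subtracting the maximum before exponentiating is a common numerical-stability trick added here, not something the formula itself requires.

```python
import math

def softmax(scores):
    """Turn per-class scores into probabilities that sum to 1."""
    m = max(scores)                           # shift by the max to avoid overflow in exp
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical scores for classes 1, 2, 3
probs = softmax([1.0, 2.0, 3.0])
print(probs)

# With two classes and the first score fixed at 0, softmax reduces to the
# binary logistic formula exp(g) / (1 + exp(g)).
g = 0.8
binary = softmax([0.0, g])[1]
print(binary, math.exp(g) / (1.0 + math.exp(g)))
```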


The dataset used here is the wine dataset [3], which is well known in data science. Part of the output is shown below; if you want to see all of it, please refer to it here :computer:

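import pandas as pd
from sklearn.datasets import load_wine

# Load the wine dataset and build a DataFrame of the 13 features
wine = load_wine()
df = pd.DataFrame(wine.data, columns=wine.feature_names)
# Map the integer targets (0, 1, 2) to their class names
df['type'] = wine.target_names[wine.target]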

Let's take a quick look at the data :chart_with_upwards_trend:

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 178 entries, 0 to 177
Data columns (total 14 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   alcohol                       178 non-null    float64
 1   malic_acid                    178 non-null    float64
 2   ash                           178 non-null    float64
 3   alcalinity_of_ash             178 non-null    float64
 4   magnesium                     178 non-null    float64
 5   total_phenols                 178 non-null    float64
 6   flavanoids                    178 non-null    float64
 7   nonflavanoid_phenols          178 non-null    float64
 8   proanthocyanins               178 non-null    float64
 9   color_intensity               178 non-null    float64
 10  hue                           178 non-null    float64
 11  od280/od315_of_diluted_wines  178 non-null    float64
 12  proline                       178 non-null    float64
 13  type                          178 non-null    object 
dtypes: float64(13), object(1)
memory usage: 19.6+ KB

| | alcohol | malic_acid | ash | alcalinity_of_ash | magnesium | total_phenols | flavanoids | nonflavanoid_phenols | proanthocyanins | color_intensity | hue | od280/od315_of_diluted_wines | proline | type |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 14.23 | 1.71 | 2.43 | 15.6 | 127.0 | 2.80 | 3.06 | 0.28 | 2.29 | 5.64 | 1.04 | 3.92 | 1065.0 | class_0 |
| 1 | 13.20 | 1.78 | 2.14 | 11.2 | 100.0 | 2.65 | 2.76 | 0.26 | 1.28 | 4.38 | 1.05 | 3.40 | 1050.0 | class_0 |
| 2 | 13.16 | 2.36 | 2.67 | 18.6 | 101.0 | 2.80 | 3.24 | 0.30 | 2.81 | 5.68 | 1.03 | 3.17 | 1185.0 | class_0 |
| 3 | 14.37 | 1.95 | 2.50 | 16.8 | 113.0 | 3.85 | 3.49 | 0.24 | 2.18 | 7.80 | 0.86 | 3.45 | 1480.0 | class_0 |
| 4 | 13.24 | 2.59 | 2.87 | 21.0 | 118.0 | 2.80 | 2.69 | 0.39 | 1.82 | 4.32 | 1.04 | 2.93 | 735.0 | class_0 |

The type variable takes exactly three categories: class_0, class_1, and class_2.


| | feature | VIF |
|---|---|---|
| 0 | alcohol | 206.189057 |
| 1 | malic_acid | 8.925541 |
| 2 | ash | 165.640370 |
| 3 | alcalinity_of_ash | 73.141564 |
| 4 | magnesium | 67.364868 |
| 5 | total_phenols | 62.786935 |
| 6 | flavanoids | 35.535602 |
| 7 | nonflavanoid_phenols | 16.636708 |
| 8 | proanthocyanins | 17.115665 |
| 9 | color_intensity | 17.022272 |
| 10 | hue | 45.398407 |
| 11 | od280/od315_of_diluted_wines | 54.539165 |
| 12 | proline | 16.370828 |

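Every feature except malic_acid has a VIF above 10, a common rule-of-thumb threshold for multicollinearity, so the interpretation of this multinomial logistic model may not be very reliable :head_bandage: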
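For reference, here is a minimal sketch of what a VIF measures, written with NumPy on synthetic data. The helper `vif` and the columns `a`, `b`, `c` are made up for illustration. It regresses each column on the others with an intercept, so its numbers will not necessarily match the table above; statsmodels' `variance_inflation_factor`, for example, does not add a constant on its own, and that choice can change the values a lot.

```python
import numpy as np

def vif(X):
    """VIF_j = 1 / (1 - R_j^2), regressing column j on the other columns plus an intercept."""
    n, p = X.shape
    out = []
    for j in range(p):
        y = X[:, j]
        others = np.delete(X, j, axis=1)
        A = np.column_stack([np.ones(n), others])      # design matrix with intercept
        coef, *_ = np.linalg.lstsq(A, y, rcond=None)   # ordinary least squares
        resid = y - A @ coef
        r2 = 1.0 - (resid @ resid) / np.sum((y - y.mean()) ** 2)
        out.append(1.0 / (1.0 - r2))
    return np.array(out)

# Synthetic example: c is almost a copy of a, b is independent of both
rng = np.random.default_rng(0)
a = rng.normal(size=200)
b = rng.normal(size=200)
c = a + 0.1 * rng.normal(size=200)
vifs = vif(np.column_stack([a, b, c]))
print(vifs)  # a and c get large VIFs, b stays close to 1
```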

/usr/local/lib/python3.10/dist-packages/statsmodels/discrete/discrete_model.py:3025: RuntimeWarning: overflow encountered in exp
  eXB = np.column_stack((np.ones(len(X)), np.exp(X)))
/usr/local/lib/python3.10/dist-packages/statsmodels/discrete/discrete_model.py:3026: RuntimeWarning: invalid value encountered in divide
  return eXB/eXB.sum(1)[:,None]
Optimization terminated successfully.
         Current function value: nan
         Iterations 24
MNLogit Regression Results
Dep. Variable:	type	No. Observations:	178
Model:	MNLogit	Df Residuals:	150
Method:	MLE	Df Model:	26
Date:	Sun, 10 Dec 2023	Pseudo R-squ.:	nan
Time:	07:21:18	Log-Likelihood:	nan
converged:	True	LL-Null:	-193.31
Covariance Type:	nonrobust	LLR p-value:	nan
type=class_1	coef	std err	z	P>|z|	[0.025	0.975]
const	nan	nan	nan	nan	nan	nan
alcohol	nan	nan	nan	nan	nan	nan
malic_acid	nan	nan	nan	nan	nan	nan
ash	nan	nan	nan	nan	nan	nan
alcalinity_of_ash	nan	nan	nan	nan	nan	nan
magnesium	nan	nan	nan	nan	nan	nan
total_phenols	nan	nan	nan	nan	nan	nan
flavanoids	nan	nan	nan	nan	nan	nan
nonflavanoid_phenols	nan	nan	nan	nan	nan	nan
proanthocyanins	nan	nan	nan	nan	nan	nan
color_intensity	nan	nan	nan	nan	nan	nan
hue	nan	nan	nan	nan	nan	nan
od280/od315_of_diluted_wines	nan	nan	nan	nan	nan	nan
proline	nan	nan	nan	nan	nan	nan
type=class_2	coef	std err	z	P>|z|	[0.025	0.975]
const	nan	nan	nan	nan	nan	nan
alcohol	nan	nan	nan	nan	nan	nan
malic_acid	nan	nan	nan	nan	nan	nan
ash	nan	nan	nan	nan	nan	nan
alcalinity_of_ash	nan	nan	nan	nan	nan	nan
magnesium	nan	nan	nan	nan	nan	nan
total_phenols	nan	nan	nan	nan	nan	nan
flavanoids	nan	nan	nan	nan	nan	nan
nonflavanoid_phenols	nan	nan	nan	nan	nan	nan
proanthocyanins	nan	nan	nan	nan	nan	nan
color_intensity	nan	nan	nan	nan	nan	nan
hue	nan	nan	nan	nan	nan	nan
od280/od315_of_diluted_wines	nan	nan	nan	nan	nan	nan
proline	nan	nan	nan	nan	nan	nan
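The two RuntimeWarnings in the log above (overflow in `exp`, then invalid value in divide) point at one plausible mechanism for the all-nan table: with unscaled features such as proline (values around 1000), the linear predictor can get large enough that `exp` returns `inf`, and `inf / inf` is `nan`. The sketch below reproduces that mechanism with NumPy. The coefficient vector is hypothetical, and whether standardizing the features actually fixes the fit above has not been checked here.

```python
import numpy as np

# alcohol and proline for three rows of the head() output above
raw = np.array([[14.23, 1065.0],
                [13.20, 1050.0],
                [13.24, 735.0]])
coef = np.array([1.0, 1.0])  # hypothetical coefficients, for illustration only

with np.errstate(over="ignore", invalid="ignore"):
    e = np.exp(raw @ coef)   # exp(~1000) overflows float64 and becomes inf
    naive = e / (1.0 + e)    # inf / inf evaluates to nan

# Standardizing each column keeps the linear predictor in a safe range
z = (raw - raw.mean(axis=0)) / raw.std(axis=0)
ez = np.exp(z @ coef)
scaled = ez / (1.0 + ez)

print(naive)   # every entry is nan
print(scaled)  # finite probabilities
```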



Warning: Maximum number of iterations has been exceeded.
         Current function value: 0.039662
         Iterations: 35
/usr/local/lib/python3.10/dist-packages/statsmodels/base/model.py:607: ConvergenceWarning: Maximum Likelihood optimization failed to converge. Check mle_retvals
  warnings.warn("Maximum Likelihood optimization failed to "
MNLogit Regression Results
Dep. Variable:	species	No. Observations:	150
Model:	MNLogit	Df Residuals:	140
Method:	MLE	Df Model:	8
Date:	Sun, 10 Dec 2023	Pseudo R-squ.:	0.9639
Time:	07:27:25	Log-Likelihood:	-5.9493
converged:	False	LL-Null:	-164.79
Covariance Type:	nonrobust	LLR p-value:	7.055e-64
species=versicolor	coef	std err	z	P>|z|	[0.025	0.975]
const	32.6240	3.87e+05	8.42e-05	1.000	-7.59e+05	7.59e+05
sepal length (cm)	-2.6491	3.42e+05	-7.74e-06	1.000	-6.71e+05	6.71e+05
sepal width (cm)	-2.1068	1.83e+05	-1.15e-05	1.000	-3.6e+05	3.6e+05
petal length (cm)	28.9332	5.68e+05	5.09e-05	1.000	-1.11e+06	1.11e+06
petal width (cm)	25.0304	4.11e+05	6.09e-05	1.000	-8.05e+05	8.05e+05
species=virginica	coef	std err	z	P>|z|	[0.025	0.975]
const	12.5222	3.87e+05	3.23e-05	1.000	-7.59e+05	7.59e+05
sepal length (cm)	-4.6837	3.42e+05	-1.37e-05	1.000	-6.71e+05	6.71e+05
sepal width (cm)	-5.0090	1.83e+05	-2.73e-05	1.000	-3.6e+05	3.6e+05
petal length (cm)	45.5233	5.68e+05	8.01e-05	1.000	-1.11e+06	1.11e+06
petal width (cm)	38.9222	4.11e+05	9.47e-05	1.000	-8.05e+05	8.05e+05


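In the iris output above, the positive coefficients for petal length (cm) and petal width (cm) indicate that these values tend to be larger for versicolor and virginica than for setosa, the reference class. Note that in multinomial logistic regression (and in ordinary logistic regression as well) the estimated partial regression coefficients are on the log-odds scale: exponentiating a coefficient gives an odds ratio, so the coefficients cannot be interpreted directly in the original scale (units) of the variable.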
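To make the odds-ratio reading concrete, here is a small sketch using the sepal length (cm) coefficient and the constant for species=versicolor from the table above. Setting every other feature to 0 is a simplifying assumption made only for this illustration, and `relative_odds` is a made-up helper name.

```python
import math

# Values copied from the species=versicolor block of the table above
const = 32.6240
coef = -2.6491   # sepal length (cm)

def relative_odds(sepal_length):
    """Odds of versicolor relative to the reference class (setosa),
    with all other features set to 0 (an assumption for illustration)."""
    return math.exp(const + coef * sepal_length)

# Increasing sepal length by 1 cm multiplies the odds by exp(coef),
# regardless of the starting value.
ratio = relative_odds(6.0) / relative_odds(5.0)
odds_ratio = math.exp(coef)
print(ratio, odds_ratio)
```

Given the huge standard errors in that table, these particular numbers should not be taken at face value; the sketch only shows how a coefficient maps to an odds ratio.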

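Also, the standard errors (std err) in this output are extremely large (on the order of 1e5), so every z statistic is close to 0 and every p-value is 1.000, which makes the test results unreliable. The optimizer also reported converged: False, which is a typical symptom of (near-)perfect separation between classes. Be careful when interpreting these numbers.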


So, how was it? :shipit:

See you next time! 👋

  1. CiNii 図書 - データ解析のためのロジスティック回帰モデル

  2. ロジスティック回帰を多クラス分類に応用する【機械学習入門17】

  3. sklearn.datasets.load_wine — scikit-learn 1.3.2 documentation

  4. sklearn.datasets.load_iris — scikit-learn 1.3.2 documentation


