内容
ロジスティクス回帰を用いて分類を行うモデルを構築します。
回帰とついてますが、分類モデルとしても使用できるので、
分類モデルの内容をまとめます。
実行内容
まずはライブラリをインポートします。
# ライブラリインポート
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
次に学習に使用するデータを読み込みます。
今回もirisデータを読み込んで分類を行います。
# データ読み込み
from sklearn.datasets import load_iris
iris = load_iris()
# データフレームを作成
data = pd.DataFrame(iris.data, columns=iris.feature_names)
# ターゲットデータを追加
data['target'] = iris.target
# (本来、irisデータはデータフレーム化まで不要ですが備忘録のために作成)
作成したデータセットをtrainデータとtestデータに分割、
標準化まで実施します。
# データセットを作成
x = data.drop('target' , axis = 1)
y = data['target']
# trainデータとtestデータに分割
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=42)
# 標準化
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train = scaler.fit_transform(x_train)
X_test = scaler.transform(x_test)
これで標準化まで完了しました。
これで学習するデータが整いましたので、学習を行います。
# インスタンス化
model = LogisticRegression()
# 学習
model.fit(x_train, y_train)
# 予測と評価
y_pred = model.predict(x_test)
accuracy = accuracy_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)
report = classification_report(y_test, y_pred)
print(f'Accuracy: {accuracy:.2f}')
print('Confusion Matrix:\n', conf_matrix)
print('Classification Report:\n', report)
結果は下記の結果が得られました。
精度が1.00なので分類できていることがわかります。
Accuracy: 1.00
Confusion Matrix:
[[10 0 0]
[ 0 9 0]
[ 0 0 11]]
Classification Report:
precision recall f1-score support
0 1.00 1.00 1.00 10
1 1.00 1.00 1.00 9
2 1.00 1.00 1.00 11
accuracy 1.00 30
macro avg 1.00 1.00 1.00 30
weighted avg 1.00 1.00 1.00 30
ロジスティクス回帰では重みを確認することで、
それぞれのクラスにどの要素がどのくらい影響するか可視化することが可能です。
この要素の可視化が一番の目的かと思うので、
可視化のコードをまとめておきます。
# プロット
plt.figure(figsize=(10, 6))
for i in range(len(model.coef_)):
plt.subplot(2, 3, i + 1)
plt.bar(x=iris.feature_names, height=model.coef_[i])
plt.title(iris.target_names[i])
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
実際のデータに合わせてグラフのXやタイトルを適宜合わせれば仕様できると思います。
回帰なのに分類できる理由
備忘録もかねて、回帰ですが分類できる理由も合わせて記載します。
まず最初に線形回帰を行います。
z = w_1x_1 + w_2x_2 + ... + w_nx_n + b
この値をシグモイド関数に使うことで 0~1 の値に変換します。
その際に閾値(通常:0.5)を定めることで閾値に対する判定により分類を行います。
また各係数は上記の $z$ を求める際に使用した係数が重みとなって出力されています。
sigmoid(z)=\frac{1}{1+{e}^{-az}}