LoginSignup
0
0

Pythonで主成分分析のbiplot関数を作ってみた

Posted at

biplotとは何かと言いますと、主成分分析で2次元に圧縮したデータの散布図に、因子負荷量のベクトルも重ねてプロットした図です。

関数(教師あり)

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
def biplot_spv(df, y):
    """Draw a PCA biplot whose points are coloured by a label.

    Every column of ``df`` is standardized (mean subtracted, divided by the
    sample standard deviation), a full PCA is fitted, and the scores on the
    first two principal components are drawn as a scatter plot.  The factor
    loadings of each original column on those two components are drawn on
    top as red line segments from the origin, labelled with the column name.

    Args:
        df: pandas DataFrame of features (presumably all numeric -- the
            standardization needs a mean and a standard deviation per
            column).  It is not modified.  At least two principal components
            must exist, because the first two are plotted.
        y: Per-row values forwarded to ``scatter`` as ``c=`` and mapped
            through the ``"brg"`` colormap.

    Returns:
        None.  The figure is displayed with ``plt.show()``.
    """
    # Standardize into a new frame so the caller's DataFrame stays untouched.
    scaled = (df - df.mean()) / df.std()

    model = PCA()
    model.fit(scaled)
    scores = model.transform(scaled)

    # Factor loadings are the principal axes scaled by the square root of the
    # corresponding eigenvalue.  ``explained_variance_`` holds the eigenvalues;
    # using the *ratio* instead would shrink every loading by a constant.
    loadings = model.components_ * np.sqrt(model.explained_variance_)[:, np.newaxis]
    pc1_loadings, pc2_loadings = loadings[0], loadings[1]

    # Scores and loadings are on different scales, so each gets its own twin
    # axis: ax2 carries the scatter, ax3 carries the loading vectors.
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax2.scatter(scores[:, 0], scores[:, 1], cmap="brg", c=y)
    ax3 = ax1.twiny()

    for name, lx, ly in zip(df.columns, pc1_loadings, pc2_loadings):
        ax3.plot([0, lx], [0, ly], color="#FF0000")
        ax3.text(lx, ly, name)

    # Symmetric limits keep the origin of the loading vectors centred.
    x_bound = np.abs(pc1_loadings).max()
    y_bound = np.abs(pc2_loadings).max()
    ax3.set_xlim(-x_bound, x_bound)
    ax3.set_ylim(-y_bound, y_bound)
    plt.show()

使用例

import pandas as pd

# Load the iris data; the "category" column holds the label and the
# remaining columns are the features fed to the PCA.
df = pd.read_csv("iris.csv")
# The supervised biplot function defined above is named biplot_spv.
biplot_spv(df.drop("category", axis=1), y=df["category"])

Untitled.png

関数(教師なし)

from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
def biplot_nonspv(df):
    """Draw a PCA biplot with all points in a single colour.

    Every column of ``df`` is standardized (mean subtracted, divided by the
    sample standard deviation), a full PCA is fitted, and the scores on the
    first two principal components are drawn as a black scatter plot.  The
    factor loadings of each original column on those two components are
    drawn on top as red line segments from the origin, labelled with the
    column name.

    Args:
        df: pandas DataFrame of features (presumably all numeric -- the
            standardization needs a mean and a standard deviation per
            column).  It is not modified.  At least two principal components
            must exist, because the first two are plotted.

    Returns:
        None.  The figure is displayed with ``plt.show()``.
    """
    # Standardize into a new frame so the caller's DataFrame stays untouched.
    scaled = (df - df.mean()) / df.std()

    model = PCA()
    model.fit(scaled)
    scores = model.transform(scaled)

    # Factor loadings are the principal axes scaled by the square root of the
    # corresponding eigenvalue.  ``explained_variance_`` holds the eigenvalues;
    # using the *ratio* instead would shrink every loading by a constant.
    loadings = model.components_ * np.sqrt(model.explained_variance_)[:, np.newaxis]
    pc1_loadings, pc2_loadings = loadings[0], loadings[1]

    # Scores and loadings are on different scales, so each gets its own twin
    # axis: ax2 carries the scatter, ax3 carries the loading vectors.
    fig, ax1 = plt.subplots()
    ax2 = ax1.twinx()
    ax2.scatter(scores[:, 0], scores[:, 1], color="#000000")
    ax3 = ax1.twiny()

    for name, lx, ly in zip(df.columns, pc1_loadings, pc2_loadings):
        ax3.plot([0, lx], [0, ly], color="#FF0000")
        ax3.text(lx, ly, name)

    # Symmetric limits keep the origin of the loading vectors centred.
    x_bound = np.abs(pc1_loadings).max()
    y_bound = np.abs(pc2_loadings).max()
    ax3.set_xlim(-x_bound, x_bound)
    ax3.set_ylim(-y_bound, y_bound)
    plt.show()

使用例

import pandas as pd
# Load the CSV file; every column is passed to the PCA as a feature.
df = pd.read_csv("seiseki.csv")
# Draw the unlabelled biplot for the loaded table.
biplot_nonspv(df)

Untitled.png

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0