LoginSignup
1
1

More than 3 years have passed since last update.

Seabornを使って分類を可視化してみた

Last updated at Posted at 2021-03-31

はじめに

初めての投稿。下記の内容を参考にJupyter Notebookで可視化をしてみた。訓練データで分類したものがテストデータにどのくらい合致しているか目で見れるようにした。

環境

Windows10
Python 3.8.5

準備

#ライブラリインポート
import matplotlib.pyplot as plt
import seaborn as sns
from matplotlib.colors import ListedColormap
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_breast_cancer
from sklearn.svm import LinearSVC

##データセットの用意
data = load_breast_cancer()
df = pd.DataFrame(data.data, columns = data.feature_names)
df.head()

Dataframeの表示。今回グラフ描写には1番目と2番目の要素を使う。
image.png

分類の可視化

#データの設定
X = data.data[:,:2]
y = data.target

# ランダムにトレーニングデータとテストデータとに分割
x_train, x_test, y_train, y_test = train_test_split(X, y, test_size=0.25,random_state=5)

# マップの色(青と赤で指定)
cmap_light = ListedColormap(['b', 'r'])
cmap_bold = ['b', 'r',]

#モデル化
svc = LinearSVC()
clf = svc
clf.fit(x_train, y_train) #訓練データを使って区分化

# メッシュサイズ
h = .02 

# プロットの境界を決める
# point in the mesh [x_min, x_max]x[y_min, y_max].
x_min, x_max = x_train[:, 0].min() - 1, x_train[:, 0].max() + 1
y_min, y_max = x_train[:, 1].min() - 1, x_train[:, 1].max() + 1
xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])

# 実際にプロットし、どのように分類されるか定義
Z = Z.reshape(xx.shape)
plt.figure(figsize=(7, 5))
plt.contourf(xx, yy, Z, cmap=cmap_light,alpha=0.5)

# 描画
sns.scatterplot(data=data, x=x_train[:, 0], y=x_train[:, 1],hue=y_train,
                    palette=cmap_bold, alpha=0.8, edgecolor="black",marker='o',s=50) #訓練データ
sns.scatterplot(x=x_test[:, 0], y=x_test[:, 1],
                    palette=cmap_bold, alpha=0.8, edgecolor='yellow',linewidth=1,marker='*',s=250,legend=False) #テストデータ
plt.legend(bbox_to_anchor=(1.2,1),loc='upper right',borderaxespad=0) #凡例をグラフの欄外に出す
plt.xlim(xx.min(), xx.max())
plt.ylim(yy.min(), yy.max())
plt.xlabel('mean radius')
plt.ylabel('mean texture')

plt.show()

黒の枠線で囲われたのが訓練データ、黄色の枠線で囲われたのがテストデータ。
image.png

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1