概要
Python + pandas + matplotlib で 相関行列(各変数間の相関係数を行列にしたもの)から、きれいに体裁を整えた ヒートマップ を作成していきます。
ここでは、例題として、次のような5科目成績の相関行列についてヒートマップを作成してみたいと思います。
実行環境
Google Colab.(Python 3.6.9)で実行・動作確認をしています。ほぼ Jupyter Notebook と同じです。
!pip list
matplotlib 3.1.2
numpy 1.17.4
pandas 0.25.3
matplotlibで日本語を使うための準備
matplotlib の出力図のなかで、日本語が使えるようにします。
!pip install japanize-matplotlib
import japanize_matplotlib
以上により、japanize-matplotlib-1.0.5
がインストール、インポートされて、ラベル等に日本語を使っても文字化け(豆腐化)しなくなります。
相関行列を求めて、とりあえずヒートマップ化
相関行列は、pandas の機能で簡単に求めることができます。
import pandas as pd
# ダミーデータ
国語 = [76, 62, 71, 85, 96, 71, 68, 52, 85, 91]
社会 = [71, 85, 64, 55, 79, 72, 73, 52, 84, 84]
数学 = [50, 78, 48, 64, 66, 62, 58, 50, 50, 60]
理科 = [37, 90, 45, 56, 59, 56, 84, 86, 51, 61]
英語 = [59, 97, 71, 85, 58, 82, 70, 61, 79, 70]
df = pd.DataFrame( {'国語':国語, '社会':社会, '数学':数学, '理科':理科, '英語':英語} )
# 相関係数を計算
df2 = df.corr()
display(df2)
行列の各要素は、$-1.0$ から $1.0$ の範囲の値をとります。この値が、$1.0$ に近いほど正の相関があり、$-1.0$ に近いほど負の相関があると判断します。$-0.2$ ~ $0.2$ の範囲では、**相関がない(無相関)**と判断します。
なお、対角要素は、同項目同士の相関係数なので $1.0$(=完全な正の相関がある)になります。
上で示したように相関係数を数値としてならべても、全体の把握が難しいので、ヒートマップを使って可視化してみます。
まずは、体裁の調整などは抜いて必要最低限のコードでヒートマップを作成してみます。
%reset -f
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors
# ダミーデータ
国語 = [76, 62, 71, 85, 96, 71, 68, 52, 85, 91]
社会 = [71, 85, 64, 55, 79, 72, 73, 52, 84, 84]
数学 = [50, 78, 48, 64, 66, 62, 58, 50, 50, 60]
理科 = [37, 90, 45, 56, 59, 56, 84, 86, 51, 61]
英語 = [59, 97, 71, 85, 58, 82, 70, 61, 79, 70]
df = pd.DataFrame( {'国語':国語, '社会':社会, '数学':数学, '理科':理科, '英語':英語} )
# 相関係数を計算
df2 = df.corr()
display(df2)
# 相関係数の行列をヒートマップで出力
plt.figure(dpi=120)
plt.imshow(df2,interpolation='nearest',vmin=-1.0,vmax=1.0)
plt.colorbar()
# 軸に項目名(国語・社会・数学・理科・英語)を出力する設定
n = len(df2.columns) # 項目数
plt.gca().set_xticks(range(n))
plt.gca().set_xticklabels(df2.columns)
plt.gca().set_yticks(range(n))
plt.gca().set_yticklabels(df2.columns)
実行結果
次のような出力を得ることができます。右側のカラーバーをもとに、紫・青の暗めの色がついているマスのところに負の相関があり、黄・緑の明るめの色がついているところに正の相関があると読み取っていきます。
正直、デフォルト設定のままでは、分かりやすいヒートマップは作成できません。
体裁を整えて美しく出力
美しく直感的にも分かりやすいヒートマップを得るためのカスタマイズを施していきます。主なポイントは、次の通りです。
- 対角成分のマスについては白色にして斜線を引く。
- カラーマップをカスタマイズして、無相関の範囲では白色になるようにする。
- グリッドを挿入する(マスとマスの間に白色の線を引く)。
- 相関係数値をマス上に印字する。
- 背景色と重なってもきれいに見えるように縁取りをする。
コード化すると次のようになります。
%reset -f
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patheffects as path_effects
import matplotlib.ticker as ticker
import matplotlib.colors
# ダミーデータ
国語 = [76, 62, 71, 85, 96, 71, 68, 52, 85, 91]
社会 = [71, 85, 64, 55, 79, 72, 73, 52, 84, 84]
数学 = [50, 78, 48, 64, 66, 62, 58, 50, 50, 60]
理科 = [37, 90, 45, 56, 59, 56, 84, 86, 51, 61]
英語 = [59, 97, 71, 85, 58, 82, 70, 61, 79, 70]
df = pd.DataFrame( {'国語':国語, '社会':社会, '数学':数学, '理科':理科, '英語':英語} )
# 相関係数を計算
df2 = df.corr()
for i in df2.index.values :
df2.at[i,i] = 0.0
# 相関係数の行列をヒートマップで出力
plt.figure(dpi=120)
# カスタムカラーマップ
cl = list()
cl.append( ( 0.00, matplotlib.colors.hsv_to_rgb((0.6, 1. ,1))) )
cl.append( ( 0.30, matplotlib.colors.hsv_to_rgb((0.6, 0.1 ,1))) )
cl.append( ( 0.50, matplotlib.colors.hsv_to_rgb((0.3, 0. ,1))) )
cl.append( ( 0.70, matplotlib.colors.hsv_to_rgb((0.0, 0.1 ,1))) )
cl.append( ( 1.00, matplotlib.colors.hsv_to_rgb((0.0, 1. ,1))) )
ccm = matplotlib.colors.LinearSegmentedColormap.from_list('custom_cmap', cl)
plt.imshow(df2,interpolation='nearest',vmin=-1.0,vmax=1.0,cmap=ccm)
# 左側に表示するカラーバーの設定
fmt = lambda p, pos=None : f'${p:+.1f}$' if p!=0 else ' $0.0$'
cb = plt.colorbar(format=ticker.FuncFormatter(fmt))
cb.set_label('相関係数', fontsize=11)
# 項目(国語・社会・数学・理科・英語)の出力に関する設定
n = len(df2.columns) # 項目数
plt.gca().set_xticks(range(n))
plt.gca().set_xticklabels(df.columns)
plt.gca().set_yticks(range(n))
plt.gca().set_yticklabels(df.columns)
plt.tick_params(axis='x', which='both', direction=None,
top=True, bottom=False, labeltop=True, labelbottom=False)
plt.tick_params(axis='both', which='both', top=False, left=False )
# グリッドに関する設定
plt.gca().set_xticks(np.arange(-0.5, n-1), minor=True);
plt.gca().set_yticks(np.arange(-0.5, n-1), minor=True);
plt.grid( which='minor', color='white', linewidth=1)
# 斜線
plt.plot([-0.5,n-0.5],[-0.5,n-0.5],color='black',linewidth=0.75)
# 相関係数を表示(文字に縁取り付き)
tp = dict(horizontalalignment='center',verticalalignment='center')
ep = [path_effects.Stroke(linewidth=3, foreground='white'),path_effects.Normal()]
for y,i in enumerate(df2.index.values) :
for x,c in enumerate(df2.columns.values) :
if x != y :
t = plt.text(x, y, f'{df2.at[i,c]:.2f}',**tp)
t.set_path_effects(ep)