はじめに
データ分析する際にSeabornを使用するのですが、グラフを並べて表示したいと思いました。
matplotlib
ではplt.subplots
を使用して1つのfig
に複数のax
を入れることができ、seabornでも同様にしてできそうかと思ったのですが、catplot
だけは簡単にできなかったので備忘録として書きます。
まずは結論
上記のリファレンスを読みseaborn.catplot(kind="swarm")
だったらseaborn.swarmplot(ax=ax)
にするだけ!!
詳しい説明
seabornでは基本的に、plotをする関数の引数のax
にmatplotlibで作成したaxを渡すだけで、図中にseabornで作成したグラフを入れることができます。
例えばkdeplot
関数では、
import seaborn as sns
y = "列名"
fig, axes = plt.subplots(2, 2, figsize=(2*1.4*3.8, 2*3.8),dpi=300, tight_layout=True, )
sns.kdeplot(hue="type", x=y, data=data, shade=False, ax=axes[0,0])
といったように引数にax
を追加して渡してやるだけでpltで作成したfigureにグラフを入れることができます。
しかし、catplot
関数に同様に引数にax
を追加して渡してやっても反映されないという問題に遭遇しました。
いろいろ調べた結果、
stripplot() (with kind="strip"; the default)
swarmplot() (with kind="swarm")
boxplot() (with kind="box")
violinplot() (with kind="violin")
boxenplot() (with kind="boxen")
pointplot() (with kind="point")
barplot() (with kind="bar")
countplot() (with kind="count")
https://seaborn.pydata.org/generated/seaborn.catplot.html
上記のリファレンスのように、catplot
関数のkind
に与えている文字+plot
という関数を代わりに使用してax
を使用すればいいということがわかりました。
なので、例えば2x2のグラフを表示させたい場合は、以下のように書くことができます。
import pandas as pd
import seaborn as sns
data = pd.DataFrame()
y = "列名"
fig, axes = plt.subplots(2, 2, figsize=(2*1.4*3.8, 2*3.8),dpi=300, tight_layout=True, )
#sns.catplot(x="type" ,y=y, data=data, kind="swarm", ax=axes[0,0])
sns.swarmplot(x="type" ,hue="type", y=y, data=data, legend=False, ax=axes[0,0])
#sns.catplot(x="type" ,y=y, data=data, kind="bar", ax=axes[0,1])
sns.barplot(x="type" ,y=y, data=data, ax=axes[0,1])
#sns.catplot(x="type" ,y=y, data=data, kind="violin", ax=axes[1,0])
sns.violinplot(x="type" ,y=y, data=data, ax=axes[1,0])
sns.kdeplot(hue="type", x=y, data=data, shade=False, ax=axes[1,1])
#fig.savefig("save" + ".png", format="png", dpi=300,transparent=True)
plt.show()
最後に
今回はSeabornのcatplotで引数にaxを渡してグラフを並べて表示させる方法を解説しました。
調べても英語の情報ばかり出てきてよくわからなかったので、参考になればと思って書きました。
それでは楽しいデータ分析ライフを!