モチベーション
Simon J. D. Prince
"Computer Vision: Models, Learning, and Inference"
(http://www.computervisionmodels.com/)
の中に出てくるヒートマップを、自分でも出してみたい!
実装
コード全体はこちら。(※補足参照)
(2019/6/9追記)
matplotlib.mlab.bivariate_normal
の代わりにscipy.stats.multivariate_normal
を使用した場合のコードも追加しました。
実装といっても、高級なことは何もしてません。まずはインポートから。
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.mlab as mlab
import seaborn as sns
seaborn
のインポートは必須ではありません。個人的な好みです。
続いてメッシュグリッドを作ります。
# メッシュグリッドの作成
X, Y = np.mgrid[-5:5:200j, -5:5:200j]
そしてこれらの格子点$(X, Y)$に対して、ガウス分布の値を求めます。
# 2次元ガウス分布
Z = mlab.bivariate_normal(X, Y, 2., 1., 0., 0.)
なお、bivariate_normal
のI/Fは↓です。
matplotlib.mlab.bivariate_normal(X, Y, sigmax=1.0, sigmay=1.0, mux=0.0, muy=0.0, sigmaxy=0.0)
そして可視化。
# ヒートマップを描画
plt.figure(figsize=(8, 6))
plt.pcolor(X, Y, Z, cmap=plt.cm.hot)
plt.colorbar()
plt.title('2次元ガウス分布')
plt.xlabel('$x$', size=12)
plt.ylabel('$y$', size=12)
plt.xlim((-5, 5))
plt.ylim((-5, 5))
plt.show()
するとこんな感じで表示されます。
よしよし。
※ちなみにplt.contour
を使うと、等高線を出すこともできます。(詳細はコード参照)
補足
- 2019/3/2現在、
bivariate_normal
については以下のワーニングが出ています。何に置き換えればよいのか?は不明。
Deprecated since version 2.2: The bivariate_normal function was deprecated in Matplotlib 2.2 and will be removed in 3.1.
(2019/6/9追記)
scipy.stats.multivariate_normal
を使用するのがよいようです。(Thanks to @Cartman0)
※
ただ、グラフの出力結果をmlab.bivariate_normal
と比べると、微妙に異なっています。
この記事の趣旨から外れるため、ここではこれ以上深入りしませんが、詳細が分かれば別途展開します。