LoginSignup
3
4

More than 3 years have passed since last update.

matplotlibで割合を表すヒートマップの作成

Posted at

matplotlibでうまいことヒートマップを描く

ヒートマップを書くのにはseabornがいい感じなんですが、うまいこと痒いところに手が届きませんでした。なので、matplotlibで何とかしました。色々と無理矢理な部分もあるので、そういう方法もあるのか...くらいに思っていただければ幸いです。

こんな図を描きたい。image.png

データの準備

np.random.seed(0)

x = np.random.randint(-50,50,100)
y = np.random.randint(1000,2000,100)
z = np.random.randint(0,10,100)

上記の3次元データに対して、xとyのある区間の間で、一定の条件を満たしているzの、z全体からの割合を図にしたいというのが元のモチベーションです。

データの集計

x軸は25ずつ、y軸は200ずつの区間に分けて集計します。

x_edges = list(range(-50,50+10,10))
y_edges = list(range(1000,2000+100,100))
obj_data = np.histogram2d(x[z_threshold], y[z_threshold], bins=[x_edges, y_edges])
all_data = np.histogram2d(x, y, bins=[x_edges, y_edges])

# 配列はnp.histogram2dの返り値の0番目に入っています
data = obj_data[0]/all_data[0]
# 軸がずれているので注意して修正
# np.histogram2dのところでうまく調整すればいけるのかも...?
data_ = np.flipud(data.T)

ここでやりたいことは、次の2つのグラフの割合を表すヒートマップを描画したい、ということです。colorbarの表示を諦めたのでめちゃくちゃ見にくいですが...

fig = plt.figure(figsize=8,16)
ax1 = plt.subplot(121)
ax2 = plt.subplot(122)
ax1.plot(obj_data)
ax1.set_title('z>=8')
ax1.set_xlabel('x')
ax1.set_ylabel('y')
ax2.plot(all_data)
ax2.set_title('all z')
ax2.set_xlabel('x')
ax2.set_ylabel('y')

image.png
各タイルごとで、z>=8/all zの割合を描画するには??

seabornで描画

まずは、seabornを使って書きます。上記のdataを与えるだけで欲しい図に近いヒートマップが手に入ります。

hm = sns.heatmap(data,annot=True)

image.png

annot=Trueとすることで数値を表示できるが、軸の調整がうまいことできない...

軸は再設定できますが、各タイルの真ん中に目盛りが欲しいわけではなく、例えば一番左下のタイルなら、
「xが-50から-40の間、yが1000から1200の間にあるデータのうち、z>=8である割合は0.2」という情報が欲しいです。
なので、以下、matplotlibで数値の表示まで無理矢理いきます。

matplotlibで描画

plt.rcParams['figure.figsize']=(8,8)
ex = [-50,50,1000,2000]
plt.imshow(data_,extent=ex,aspect=(ex[1]-ex[0])/(ex[3]-ex[2]),cmap='autumn_r')
plt.colorbar(shrink=0.82)

image.png

ここでextentを設定することで、軸ラベルを再認識させることができます。
ただ、かなりやっつけな仕事がボロボロ出てきて、
- extentを設定した結果、横軸の範囲が100、縦軸の範囲が1000になって図の縦横比が大幅にずれる。そのためaspectの設定でaspect比を合わせる。
- あとの文字表示の都合上、暗い色だと大変なので明るいcmapを設定している
- aspect比をいじった結果、colorbarの長さがおかしくなるので、shrinkの設定でむりやり(数値を自分でいじって)調整

わあ...姑息...このへんのその場しのぎの調整は後で余裕が出た時に調べ直します...

文字を乗せる

ex = [-50,50,1000,2000]
plt.imshow(data_,extent=ex,cmap='autumn_r')
plt.gca().set_aspect((ex[1]-ex[0])/(ex[3]-ex[2]))
plt.colorbar(shrink=0.82)

xs, ys = np.meshgrid(np.array(x_edges[:-1])+10/2,np.array(y_edges[:-1])+200/2,indexing='ij')
for (x,y,text) in zip(xs.flatten(), ys.flatten(), data.flatten()):
    plt.text(x,y,'{}'.format(np.round(text,2)),horizontalalignment='center',verticalalignment='center')

上記で表示したグラフとは別にnp.meshgridでメッシュを用意して、その座標にテキストを置いていく。
ただし、用意したx_edges,y_edgesをそのまま使ってしまうと、各々のタイルの継ぎ目の部分に文字が置かれてしまうため、タイルのど真ん中を指定するメッシュを用意した。この辺りの手法については、こちらのページを参考にさせていただきました。

xs, ys = np.meshgrid(np.array(x_edges[:-1])+10/2,np.array(y_edges[:-1])+200/2,indexing='ij')

テキスト表示に使った配列は、向きを変えたdata_ = np.flipud(data.T)ではなく、np.histogram2dから計算したそのままのdataであることに注意。
この辺りから考えるにしても、もう少しデータの配列についてうまいことしなきゃいけなさそう。

結果

image.png
無事こんな図が描けました。

今後の課題

今回その場しのぎで解決した部分をちゃんと調整できるように調べて追記する。

もし部分的にでも解決策をご存知の方いれば、ご教授いただけると嬉しいです。

3
4
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4