matplotlibでうまいことヒートマップを描く
ヒートマップを書くのにはseabornがいい感じなんですが、うまいこと痒いところに手が届きませんでした。なので、matplotlibで何とかしました。色々と無理矢理な部分もあるので、そういう方法もあるのか...くらいに思っていただければ幸いです。
データの準備
np.random.seed(0)
x = np.random.randint(-50,50,100)
y = np.random.randint(1000,2000,100)
z = np.random.randint(0,10,100)
上記の3次元データに対して、xとyのある区間の間で、一定の条件を満たしているzの、z全体からの割合を図にしたいというのが元のモチベーションです。
データの集計
x軸は25ずつ、y軸は200ずつの区間に分けて集計します。
x_edges = list(range(-50,50+10,10))
y_edges = list(range(1000,2000+100,100))
obj_data = np.histogram2d(x[z_threshold], y[z_threshold], bins=[x_edges, y_edges])
all_data = np.histogram2d(x, y, bins=[x_edges, y_edges])
# 配列はnp.histogram2dの返り値の0番目に入っています
data = obj_data[0]/all_data[0]
# 軸がずれているので注意して修正
# np.histogram2dのところでうまく調整すればいけるのかも...?
data_ = np.flipud(data.T)
ここでやりたいことは、次の2つのグラフの割合を表すヒートマップを描画したい、ということです。colorbarの表示を諦めたのでめちゃくちゃ見にくいですが...
fig = plt.figure(figsize=8,16)
ax1 = plt.subplot(121)
ax2 = plt.subplot(122)
ax1.plot(obj_data)
ax1.set_title('z>=8')
ax1.set_xlabel('x')
ax1.set_ylabel('y')
ax2.plot(all_data)
ax2.set_title('all z')
ax2.set_xlabel('x')
ax2.set_ylabel('y')
各タイルごとで、z>=8/all z
の割合を描画するには??
seabornで描画
まずは、seabornを使って書きます。上記のdata
を与えるだけで欲しい図に近いヒートマップが手に入ります。
hm = sns.heatmap(data,annot=True)
annot=True
とすることで数値を表示できるが、軸の調整がうまいことできない...
軸は再設定できますが、各タイルの真ん中に目盛りが欲しいわけではなく、例えば一番左下のタイルなら、
「xが-50から-40の間、yが1000から1200の間にあるデータのうち、z>=8である割合は0.2」という情報が欲しいです。
なので、以下、matplotlibで数値の表示まで無理矢理いきます。
matplotlibで描画
plt.rcParams['figure.figsize']=(8,8)
ex = [-50,50,1000,2000]
plt.imshow(data_,extent=ex,aspect=(ex[1]-ex[0])/(ex[3]-ex[2]),cmap='autumn_r')
plt.colorbar(shrink=0.82)
ここでextent
を設定することで、軸ラベルを再認識させることができます。
ただ、かなりやっつけな仕事がボロボロ出てきて、
-
extent
を設定した結果、横軸の範囲が100、縦軸の範囲が1000になって図の縦横比が大幅にずれる。そのためaspect
の設定でaspect比を合わせる。 - あとの文字表示の都合上、暗い色だと大変なので明るい
cmap
を設定している - aspect比をいじった結果、colorbarの長さがおかしくなるので、
shrink
の設定でむりやり(数値を自分でいじって)調整
わあ...姑息...このへんのその場しのぎの調整は後で余裕が出た時に調べ直します...
文字を乗せる
ex = [-50,50,1000,2000]
plt.imshow(data_,extent=ex,cmap='autumn_r')
plt.gca().set_aspect((ex[1]-ex[0])/(ex[3]-ex[2]))
plt.colorbar(shrink=0.82)
xs, ys = np.meshgrid(np.array(x_edges[:-1])+10/2,np.array(y_edges[:-1])+200/2,indexing='ij')
for (x,y,text) in zip(xs.flatten(), ys.flatten(), data.flatten()):
plt.text(x,y,'{}'.format(np.round(text,2)),horizontalalignment='center',verticalalignment='center')
上記で表示したグラフとは別にnp.meshgrid
でメッシュを用意して、その座標にテキストを置いていく。
ただし、用意したx_edges,y_edges
をそのまま使ってしまうと、各々のタイルの継ぎ目の部分に文字が置かれてしまうため、タイルのど真ん中を指定するメッシュを用意した。この辺りの手法については、こちらのページを参考にさせていただきました。
xs, ys = np.meshgrid(np.array(x_edges[:-1])+10/2,np.array(y_edges[:-1])+200/2,indexing='ij')
テキスト表示に使った配列は、向きを変えたdata_ = np.flipud(data.T)
ではなく、np.histogram2d
から計算したそのままのdata
であることに注意。
この辺りから考えるにしても、もう少しデータの配列についてうまいことしなきゃいけなさそう。
今後の課題
今回その場しのぎで解決した部分をちゃんと調整できるように調べて追記する。
もし部分的にでも解決策をご存知の方いれば、ご教授いただけると嬉しいです。