Python の matplotlib でヒストグラムをよく描くわけですが、その縦軸・横軸がときどき気に入らないので、それを微調整するためのメモです。
hist の縦軸と横軸
例として次のようなグラフをお見せします。
%matplotlib inline
import matplotlib.pyplot as plt
from scipy import stats
norm_rvs = stats.norm.rvs(loc=50, scale=30, size=100, random_state=0)
plt.hist(norm_rvs, bins=10, alpha=0.5, ec='navy')
plt.show()
これを見て、
- うえ〜ん、ヒストグラムの棒の切れ目が中途半端な場所にあって気持ち悪いヨー!
- うえ〜ん、縦軸の目盛りが整数じゃないと気持ち悪いヨー!
ってなるわけです。
縦軸の目盛りを整数にしたい
次のようにすれば、ヒストグラムの棒の切れ目と、高さに関する情報が得られます。
Y, X, _ = plt.hist(norm_rvs, bins=10, alpha=0.5, ec='navy')
print(X)
print(Y)
plt.show()
[-26.58969448 -12.12146116 2.34677216 16.81500548 31.2832388
45.75147212 60.21970544 74.68793876 89.15617208 103.6244054
118.09263872]
[ 1. 5. 7. 13. 17. 18. 16. 11. 7. 5.]
その情報を用いて縦軸を整数にしてみます。
import numpy as np
Y, X, _ = plt.hist(norm_rvs, bins=10, alpha=0.5, ec='navy')
y_max = int(max(Y)) + 1
plt.yticks(np.arange(0, y_max, 2)) # 1刻みにしても見にくいので2刻みにします
plt.show()
棒の区切りをいい感じにしたい
横軸の範囲を指定して、binの本数をいい感じに調整します。
Y, X, _ = plt.hist(norm_rvs, bins=13, alpha=0.5, ec='navy', range=(-10, 120))
print(X)
print(Y)
y_max = int(max(Y)) + 1
plt.yticks(np.arange(0, y_max, 2))
plt.show()
[-10. 0. 10. 20. 30. 40. 50. 60. 70. 80. 90. 100. 110. 120.]
[ 3. 5. 6. 10. 11. 9. 15. 13. 9. 6. 5. 5. 2.]
複数のヒストグラムをいい感じにしたい
さて、複数のヒストグラムを並べて比較したいときがあると思いますが
norm_rvs2 = stats.norm.rvs(loc=75, scale=55, size=100, random_state=0)
plt.hist(norm_rvs, bins=10, alpha=0.5, ec='navy')
plt.hist(norm_rvs2, bins=10, alpha=0.5, ec='red')
plt.show()
こんなふうに気持ち悪いヨ〜!ってなりがちですね。これも同じようにいい感じにしてみましょう。
bins = 20
range=(-50, 200)
Y1, X1, _ = plt.hist(norm_rvs, bins=bins, alpha=0.5, ec='navy', range=range)
Y2, X2, _ = plt.hist(norm_rvs2, bins=bins, alpha=0.5, ec='red', range=range)
y_max = int(max(max(Y1), max(Y2))) + 1
plt.yticks(np.arange(0, y_max, 2))
plt.show()
個人的には、次のように縦に並べる方が好きです。
bins = 20
range=(-50, 200)
fig, axes = plt.subplots(nrows=2, ncols=1, figsize=(8,8))
Y1, X1, _ = axes[0].hist(norm_rvs, bins=bins, alpha=0.5, ec='navy', range=range)
Y2, X2, _ = axes[1].hist(norm_rvs2, bins=bins, alpha=0.5, ec='red', range=range)
y_max = int(max(max(Y1), max(Y2))) + 1
axes[0].set_ylim([0, y_max])
axes[1].set_ylim([0, y_max])
axes[0].set_yticks(np.arange(0, y_max, 2))
axes[1].set_yticks(np.arange(0, y_max, 2))
plt.show()
現場からは以上です!