Python
matplotlib

matplotlibの各種設定をコード内で記述するスタイルを考える


TL;DR

plt.rcParams['foo.bar']がいっぱい並ぶのが嫌なので、簡潔な書き方を考えてみました。


環境


  • python 3.6.5

  • matplotlib 3.0.1


コード内に設定を書く目的

matplotlibは、Pythonでちょっとグラフを見たいときには大変便利です。

ただ、きちんと体裁を整えようとすると様々な設定を追加する必要が生じます。

オススメの設定については以下の記事に詳しいです。

このような設定は~/.matplotlib/matplotlibrcにも書けますが、書きたいグラフに応じて設定を変えたいときにいちいち編集していては面倒です。コード中に設定があれば、目的に合わせてすぐにカスタマイズできます。


よく見る書き方

基本的には、上に挙げた記事にもあるようにplt.rcParams['foo.bar']='value'で設定をいじります。2つの記事から設定をお借りしてまとめてみたものがこちらです。

plt.rcParams['axes.grid'] = True

plt.rcParams['axes.linewidth'] = 1.2
plt.rcParams['font.family'] ='sans-serif'
plt.rcParams['font.sans-serif'] = ['Arial']
plt.rcParams['font.size'] = 8
plt.rcParams['grid.linestyle'] = '--'
plt.rcParams['grid.linewidth'] = 0.3
plt.rcParams['legend.edgecolor'] = 'black'
plt.rcParams['legend.fancybox'] = False
plt.rcParams['legend.framealpha'] = 1
plt.rcParams['xtick.direction'] = 'in'
plt.rcParams['xtick.major.width'] = 1.2
plt.rcParams['ytick.direction'] = 'in'
plt.rcParams['ytick.major.width'] = 1.2

この設定で試しに乱数をプロットしてみると、以下のようになります。

scatter.png

コード

import matplotlib.pyplot as plt

import numpy as np

x, y1, y2 = np.random.rand(3, 30)
fig, ax = plt.subplots()
ax.scatter(x, y1, label='data1')
ax.scatter(x, y2, label='data2')
ax.legend()

plt.show()


これで見た目が整いましたが、plt.rcParams['foo.bar']が多くて圧迫感のあるコードになってしまいました。もっと簡潔に書けないでしょうか。


dictionaryにまとめてdict.update()

plt.rcParamsは辞書なので、update()で上書きできます。

plt.rcParams.update({

'axes.grid' : True,
'axes.linewidth' : 1.2,
'font.family' :'sans-serif',
'font.sans-serif' : ['Arial'],
'font.size' : 8,
'grid.linestyle' : '--',
'grid.linewidth' : 0.3,
'legend.edgecolor' : 'black',
'legend.fancybox' : False,
'legend.framealpha' : 1,
'xtick.direction' : 'in',
'xtick.major.width' : 1.2,
'ytick.direction' : 'in',
'ytick.major.width' : 1.2
})

大量にplt.rcParamsを書く必要がなくなり、かなり見やすくなりました。

JSON設定ファイルを書いている気分です。


さらにまとめてみる

上の記述法でも十分ですが、まだlegend.などが重複しています。折角なのでこれもまとめてしまいましょう。

{

'key1.key2_a' : 'val_a'
'key1.key2_b' : 'val_b'
}

{

'key1' : {
'key2_a' : 'val_a',
'key2_b' : 'val_b'
}
}

と書きたいので、後者を前者に内包表記で変換します。

dict2rc = lambda dict: {f'{k1}.{k2}':v for k1,d in dict.items() for k2,v in d.items()}

'''
同じ処理をfor文で書くと
def dict2rc(dict) :
res = {}
for k1,d in dict.items():
for k2,v in d.items() :
res[f'{k1}.{k2}'] = v
return res
'''

これを用いると以下のように書けます。

plt.rcParams.update(dict2rc({

'axes' : {
'grid' : True,
'linewidth' : 1.2
},
'font' : {
'family' :'sans-serif',
'sans-serif' : ['Arial'],
'size' : 8
},
'grid' : {
'linestyle' : '--',
'linewidth' : 0.3
},
'legend' : {
'edgecolor' : 'black',
'fancybox' : False,
'framealpha' : 1
},
'xtick' : {
'direction' : 'in',
'major.width' : 1.2
},
'ytick' : {
'direction' : 'in',
'major.width' : 1.2
}
}))

横方向にはスッキリしますが、縦に長くなりすぎな感じもします。

見なくていい部分を折り畳み機能のあるエディタで隠しておけば見やすくなるかもしれません。