LoginSignup
1
5

More than 3 years have passed since last update.

Matplotlibで図の下側に情報を追記する

Last updated at Posted at 2020-04-15

一覧表に戻る

図の周りに文字情報を追記したい場合があります。plt.text でも文字情報を追記できますが、座標をいちいち指定しなければならないので面倒です。以下のコードでは、表示したい文字列を透明な点 (alpha=0) のラベルとして与え、legend に描画させることで、位置を指定せずに追記しています。

image.png

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
%matplotlib inline
plt.rcParams['font.size'] = 15

def r2(y1, y2):
    """Return the squared Pearson correlation of *y1* and *y2* as a string.

    The result is shown in the figure under an ``R^2`` label, so the
    correlation coefficient r is squared (r**2) before being rounded to
    3 decimal places.

    Parameters
    ----------
    y1, y2 : array-like
        Equal-length 1-D sequences of numbers (here: measured vs. predicted).

    Returns
    -------
    str
        ``r**2`` rounded to 3 decimals, e.g. ``'0.64'``. May be ``'nan'``
        when an input has zero variance (``np.corrcoef`` cannot normalise).
    """
    # Off-diagonal entry of the 2x2 correlation matrix is Pearson's r.
    r = np.corrcoef(y1, y2)[0, 1]
    # Square before rounding so the displayed value is R^2, not r.
    return str(np.round(r ** 2, 3))

# Synthetic "measured vs predicted" data for two models (a and b): each
# prediction is the measured value plus standard-normal noise (unseeded, so
# every run draws different points).
xa_train = [1,3,5,7]
xa_test  = [2,4,6,9]
ya_train = xa_train + np.random.randn(4)
ya_test  = xa_test  + np.random.randn(4)

xb_train = [1,3,5,7]
xb_test  = [2,4,6,9]
yb_train = xb_train + np.random.randn(4)
yb_test  = xb_test  + np.random.randn(4)

_AXIS_TICKS = [0, 2, 4, 6, 8, 10]


def _draw_parity_panel(ax, x_train, y_train, x_test, y_test, title):
    """Scatter train (black) and test (red) points with a y = x guide on *ax*."""
    ax.scatter(x_train, y_train, color='k', label='train')
    ax.scatter(x_test, y_test, color='r', label='test')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ax.set_xticks(_AXIS_TICKS)
    ax.set_yticks(_AXIS_TICKS)
    # Diagonal reference: points on it are perfectly predicted.
    ax.plot([0, 10], [0, 10], color='gray', lw=0.5)
    ax.grid()
    ax.set_title(title)
    ax.set_xlabel('measured')


def _draw_score_panel(ax, x_train, y_train, x_test, y_test):
    """Write the train/test R^2 values as frameless legend text on a blank *ax*."""
    ax.axis('off')  # hide ticks, tick labels and spines; only the text remains
    label = ('$R^2_{train}=$' + r2(x_train, y_train)
             + '\n$R^2_{test}=$' + r2(x_test, y_test))
    # An invisible point (alpha=0) carries the text as its legend label, so
    # the legend positions it at 'upper left' without manual coordinates.
    ax.scatter(0, 0, label=label, alpha=0)
    ax.legend(frameon=False, loc='upper left')


fig = plt.figure()
fig.subplots_adjust(wspace=0.2, hspace=0.4)

# Top row: two tall parity plots; bottom row: short strips for the scores.
gs = gridspec.GridSpec(2, 2, width_ratios=[1,1], height_ratios=[4,1])

ax_a = fig.add_subplot(gs[0])
_draw_parity_panel(ax_a, xa_train, ya_train, xa_test, ya_test, 'train')
ax_a.set_ylabel('predicted')

ax_b = fig.add_subplot(gs[1])
_draw_parity_panel(ax_b, xb_train, yb_train, xb_test, yb_test, 'test')
# Shares the y scale with the left panel, so its y tick labels are dropped.
ax_b.tick_params(left=False, labelleft=False)
# Legend placed just outside the right edge of the top-right panel.
ax_b.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)

_draw_score_panel(fig.add_subplot(gs[2]), xa_train, ya_train, xa_test, ya_test)
_draw_score_panel(fig.add_subplot(gs[3]), xb_train, yb_train, xb_test, yb_test)

plt.show()
1
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
5