図の周りに文字情報を追記したい場合があります。`plt.text` でも文字情報を追記できますが、いちいち表示位置を座標で指定しなければならないので面倒です。
以下のコードでは `legend` を使うことで、位置を座標で指定せずに文字情報を追記しています。具体的には、透明（`alpha=0`）なダミーの点を描画し、そのラベルに表示したい文字列を設定して凡例として表示させています。
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
# IPython/Jupyter magic selecting matplotlib's inline backend (figures are
# embedded in the notebook output). This is not plain-Python syntax, so this
# code must run under IPython/Jupyter; drop the line to run it as a .py file.
%matplotlib inline
# Default font size for text in figures created after this point.
plt.rcParams['font.size'] = 15
def r2(y1, y2, ndigits=3):
    """Return the squared Pearson correlation (r^2) of *y1* and *y2* as text.

    Intended for plot labels written as ``R^2``. ``np.corrcoef`` alone gives
    r, which can be negative, so it is squared here. Note that r^2 equals the
    coefficient of determination (1 - SS_res / SS_tot) only for a
    least-squares linear fit; for raw predictions vs. measurements the two
    can differ.

    Args:
        y1, y2: Equal-length 1-D sequences of numbers.
        ndigits: Number of decimal places in the returned string.

    Returns:
        r^2 formatted with exactly ``ndigits`` decimals, e.g. ``'0.640'``.
        ``'nan'`` if either input is constant (correlation undefined).
    """
    r = np.corrcoef(y1, y2)[0, 1]
    # Fixed-width formatting keeps train/test labels aligned ('0.900', not '0.9').
    return f"{r ** 2:.{ndigits}f}"
# Synthetic "measured" values for the two scatter panels (suffix a / b),
# each split into train and test points; both panels use the same points.
xa_train = [1, 3, 5, 7]
xa_test = [2, 4, 6, 9]
xb_train = [1, 3, 5, 7]
xb_test = [2, 4, 6, 9]

# "Predicted" values = measured value + standard-normal noise. A seeded
# Generator makes the figure and the R^2 labels reproducible between runs
# (the global np.random state would give different values every time).
rng = np.random.default_rng(0)
ya_train = np.asarray(xa_train) + rng.standard_normal(len(xa_train))
ya_test = np.asarray(xa_test) + rng.standard_normal(len(xa_test))
yb_train = np.asarray(xb_train) + rng.standard_normal(len(xb_train))
yb_test = np.asarray(xb_test) + rng.standard_normal(len(xb_test))
def _draw_parity_panel(ax, x_train, y_train, x_test, y_test, title):
    """Scatter train (black) and test (red) points on a 0-10 parity plot.

    Sets limits, ticks, a y = x reference line, grid, title and x label on
    *ax*. The caller adds the parts that differ per panel (y label, legend).
    """
    ticks = [0, 2, 4, 6, 8, 10]
    ax.scatter(x_train, y_train, color='k', label='train')
    ax.scatter(x_test, y_test, color='r', label='test')
    ax.set_xlim(0, 10)
    ax.set_ylim(0, 10)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    # y = x reference line: points lying on it are predicted exactly.
    ax.plot([0, 10], [0, 10], color='gray', lw=0.5)
    ax.grid()
    ax.set_title(title)
    ax.set_xlabel('measured')


def _draw_text_panel(ax, text):
    """Display *text* in *ax* as a legend entry, with no coordinates needed.

    Hides ticks, tick labels and spines, then plots a single fully transparent
    dummy point whose legend label is *text*; the legend's ``loc`` places it.
    """
    ax.tick_params(left=False, labelleft=False, bottom=False, labelbottom=False)
    for side in ('top', 'bottom', 'left', 'right'):
        ax.spines[side].set_visible(False)
    ax.scatter(0, 0, label=text, alpha=0)
    ax.legend(frameon=False, loc='upper left')


def _r2_text(x_train, y_train, x_test, y_test):
    """Build the two-line train/test R^2 label (values formatted by ``r2``)."""
    return ('$R^2_{train}=$' + r2(x_train, y_train)
            + '\n$R^2_{test}=$' + r2(x_test, y_test))


fig = plt.figure()
# Top row: two scatter panels; bottom row (1/4 of the height): text panels.
gs = gridspec.GridSpec(2, 2, figure=fig, height_ratios=[4, 1],
                       wspace=0.2, hspace=0.4)

# NOTE(review): both panels plot train AND test points, so the titles
# 'train' / 'test' may really be meant as dataset a / b labels — confirm.
ax_a = fig.add_subplot(gs[0])
_draw_parity_panel(ax_a, xa_train, ya_train, xa_test, ya_test, 'train')
ax_a.set_ylabel('predicted')

ax_b = fig.add_subplot(gs[1])
_draw_parity_panel(ax_b, xb_train, yb_train, xb_test, yb_test, 'test')
# Same 0-10 y range as the left panel, so its y tick labels are redundant.
ax_b.tick_params(left=False, labelleft=False)
# Legend anchored just outside the right edge of the panel, top-aligned.
ax_b.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0)

_draw_text_panel(fig.add_subplot(gs[2]),
                 _r2_text(xa_train, ya_train, xa_test, ya_test))
_draw_text_panel(fig.add_subplot(gs[3]),
                 _r2_text(xb_train, yb_train, xb_test, yb_test))

plt.show()