Matplotlib のscatter(散布図)から、各点の座標を取得する
上記を調べてもヒットしなかったのでメモしておく。
plt.scatter() やsns.scatteplot()で作図した散布図の各点の座標を取得する。
scatterでは戻り値が「PathCollection」なのだが、ここから各点の座標を取得する方法を調べても公式ドキュメントに記されている構造が難解で辿り着けない。
-
seabornを用いる(sns.scatterplot())
- 戻り値はaxes。axes.collections[0].get_offsets()、と辿る
→[[$x_1$ $y_1$] [$x_2$ $y_2$] [...]...] のリストを取得
- 戻り値はaxes。axes.collections[0].get_offsets()、と辿る
ax=sns.scatterplot(x=x, y=experiment_data)
points=ax.collections[0].get_offsets()
for point in points:
print(f"x : {point[0]}, y : {point[1]}")
plt.show()
→point[0]が各$x_i$、point[1]が各$y_i$
-
plt.scatterを用いる
- 戻り値がPathCollectionなので、 .get_offsets() と書けば良い
→[[$x_1$ $y_1$] [$x_2$ $y_2$] [...]...] のリストを取得
- 戻り値がPathCollectionなので、 .get_offsets() と書けば良い
pc=plt.scatter(x,experiment_data)
points=pc.get_offsets()
for point in points:
print(f"x : {point[0]}, y : {point[1]}")
plt.show()
→point[0]が各$x_i$、point[1]が各$y_i$
以下、実験誤差を表示するイメージのコード。
seabornのscatterplot()を用いた場合と、plt.scatter()を用いた場合の2個を書く。
目的は、各点から理論値直線への垂線(赤点線)を引くこととする。
また、値の確認用に、定義したx,yと散布図から取得したx,yをそれぞれprint出力している。
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
sns.set()
#define x value, 10 points
x = np.arange(0,10,1)
#define y value, 10 points. y= x + random , random is in normal distribution
rng = np.random.default_rng()
experiment_data = x + rng.normal(0,1,10)
print("Show defined each x & y values")
for i,j in zip(x,experiment_data):
print(f"x : {i}, y : {j}")
print()
#========================================
#===plot using "seaborn.scatterplot()"===
ax=sns.scatterplot(x=x, y=experiment_data)
#draw line of y=x. as ideal line.
plt.axline((0,0),slope=1,color="r", linestyle="-", linewidth=0.8)
ax.axis("square")
#get each xy coordinates of scatterplot.
points=ax.collections[0].get_offsets()
print("Retrieve values from sns.scatterplot():")
for point in points:
print(f"x : {point[0]}, y : {point[1]}")
#draw dotted line onto the ideal line
plt.vlines(point[0],point[0],point[1],linestyle="--", color="tab:red",linewidth=0.8)
ax.set_xlabel("x : actual")
ax.set_ylabel("y : experimental")
plt.annotate("retlieve each xy value from scatter points \n and draw dotted line", xy=(0.1,0.9), xycoords="axes fraction",color="r", fontsize=8)
plt.title("seaborn.scatterplot() example")
plt.show()
#===================================
#====plot using "plt.scatter()"=====
pc=plt.scatter(x,experiment_data)
#draw line of y=x. as ideal line.
plt.axline((0,0),slope=1,color="r", linestyle="-", linewidth=0.8)
plt.axis("square")
#get each xy coordinates of scatterplot.
points=pc.get_offsets()
print("Retrieve values from plt.scatter():")
for point in points:
print(f"x : {point[0]}, y : {point[1]}")
#draw dotted line onto the ideal line
plt.vlines(point[0],point[0],point[1],linestyle="--", color="tab:red",linewidth=0.8)
plt.xlabel("x : actual")
plt.ylabel("y : experimental")
plt.annotate("retlieve each xy value from scatter points \n and draw dotted line", xy=(0.1,0.9), xycoords="axes fraction",color="r", fontsize=8)
plt.title("plt.scatter() example")
plt.show()
出力:
Show defined each x & y values
x : 0, y : -0.3636407569331496
x : 1, y : 0.2513029225750196
x : 2, y : 3.0583493975231795
x : 3, y : 3.480049609935434
x : 4, y : 4.49850625721525
x : 5, y : 4.780825044490773
x : 6, y : 3.9428596747818956
x : 7, y : 7.693574169170814
x : 8, y : 8.982012651087544
x : 9, y : 10.763739268926154
Retrieve values from sns.scatterplot():
x : 0.0, y : -0.3636407569331496
x : 1.0, y : 0.2513029225750196
x : 2.0, y : 3.0583493975231795
x : 3.0, y : 3.480049609935434
x : 4.0, y : 4.49850625721525
x : 5.0, y : 4.780825044490773
x : 6.0, y : 3.9428596747818956
x : 7.0, y : 7.693574169170814
x : 8.0, y : 8.982012651087544
x : 9.0, y : 10.763739268926154
Retrieve values from plt.scatter():
x : 0.0, y : -0.3636407569331496
x : 1.0, y : 0.2513029225750196
x : 2.0, y : 3.0583493975231795
x : 3.0, y : 3.480049609935434
x : 4.0, y : 4.49850625721525
x : 5.0, y : 4.780825044490773
x : 6.0, y : 3.9428596747818956
x : 7.0, y : 7.693574169170814
x : 8.0, y : 8.982012651087544
x : 9.0, y : 10.763739268926154