0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

matplotlibで複数系列の点をプロットして回帰式を描いた

Last updated at Posted at 2022-07-27

はじめに

エクセルでやってみると結構めんどくさかったので書いた。
以下のようなグラフを作成できる。
TN.png

グラフ作成

解説

pandasでcsvを読み込んでmatplotを使って読み込んだ全ての値を用いて回帰式を作成。
回帰式を表示した後に, csvのmarkの項目に従って点の形を変えてプロット。

データセットは以下の形式のものを扱う(.csv)。
xValue, yValue, mark
0.341, 0.6546, winter
0.421, 0.5341, winter
0.113, 0.1536, summer
0.114, 0.1342, summer
.
.
.

ソースコード

from turtle import color
import matplotlib.pyplot as plt
import numpy.polynomial.polynomial as P
import pandas as pd
import sklearn.metrics as metrics

def func(_x, _a, _b):
    return _a*_x + _b

#回帰式計算
def regression(_x, _y):
    coef = P.polyfit(_x, _y,1)
    #傾き
    a = coef[1]
    #切片
    b = coef[0]
    #近似直線の決定係数
    r2 = metrics.r2_score(_y, func(_x, a, b))
    return a, b, r2

#式情報をグラフに表示
def putEquationText(_a, _b, _r2, _start, _end, _ax):
    #回帰式をグラフ上に表示(表示場所も決定)
    text_x = _start*0.4 + _end*0.6
    if   _a > 0 and _b > 0:
        _ax.text(text_x, func(text_x, _a, _b)*0.8, f'y = {round(_a, 2)}x + {round(_b, 2)}', fontsize = 12)
    elif _a > 0 and _b < 0:
        _ax.text(text_x, func(text_x, _a, _b)*0.8, f'y = {round(_a, 2)}x - {-1*round(_b, 2)}', fontsize = 12)
    elif _a < 0 and _b > 0:
        _ax.text(text_x, func(text_x, _a, _b)*0.8, f'y = -{-1*round(_a, 2)}x + {round(_b, 2)}', fontsize = 12)
    elif _a < 0 and _b < 0:
        _ax.text(text_x, func(text_x, _a, _b)*0.8, f'y = -{-1*round(_a, 2)}x - {-1*round(_b, 2)}', fontsize = 12)
    else:
        _ax.text(text_x, func(text_x, _a, _b)*0.8, f'y = -{-1*round(_a, 2)}x', fontsize = 12)
    #決定係数の表示
    _ax.text(text_x, func(text_x, _a, _b)*0.65, r'$r^2 = {}$'.format(round(_r2, 2)), fontsize = 12)

#点をグラフ上に表示
def putSeveralMark(_df, _x, _y, _mark ,_plt):
    #要素ごとにプロット
    markers = ["x","+",",", "v", "^", "<", ">", "1", "2", "3"]
    marks = _df[_mark].unique()
    for mark, i in zip(marks, range(len(marks))):
        df_each = _df[_df[_mark] == mark]
        _plt.scatter(df_each[_x], df_each[_y], s = 100, marker=markers[i], color = "black")

#csvをpandasで読み込み
path = "input.csv"
df = pd.read_csv(path)
#列名取得
df_column = df.columns.values
x_column = df_column[0]
y_column = df_column[1]
mark_column = df_column[2]

#x座標の最大値, 最小値から直線の表示域を決定
start = df[x_column].min()*0.9
end = df[x_column].max()*1.1

#グラフ定義
plt.rcParams['figure.subplot.bottom'] = 0.15
plt.rcParams['lines.linewidth'] = 3
fig, ax = plt.subplots()

#回帰式を計算
a, b, r2 = regression(df[x_column], df[y_column])

#回帰式を表示
plt.plot([start, end], [a*start+b, a*end+b], color = "blue")

#軸ラベルを表示
ax.set_xlabel(r'$flow(m^3/s)$', fontsize = 12)
ax.set_ylabel(r'$TN(mg/L)$', fontsize = 12)

#直線の式, 決定係数を表示
putEquationText(a, b, r2, start, end, ax)

#点をプロット
putSeveralMark(df, x_column, y_column, mark_column ,plt)

#画像を保存
plt.savefig("output.png", dpi=200, bbox_inches="tight", pad_inches=0.1)
plt.show()

おわりに

適宜改善して使ってくれると喜びます。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?