6
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

ラクスパートナーズAdvent Calendar 2023

Day 12

継続的にモデルを活用すために:統計手法とPythonで実装

Last updated at Posted at 2023-12-11

1.あともう少し!メリークリスマス!

この記事は、ラクスパートナーズ Advent Calendar 2023 12日目の記事です
弊社はITに関わるすべての人たちを応援する楽楽パートナーを掲げ、
日々お客様の業務の支援を行なっております

2.今回扱うデータセット

カリフォルニアの住宅価格になります
今回は、各地点の平均の住宅価格の予測を行います

longitude:緯度

latitude:経度

housing_median_age:築年数の中央値

total_rooms:合計部屋数

total_bedrooms:合計部屋数

population:人口

households:世帯数

median_income:中央値所得

median_house_value:住宅価格の中央値 ⇦今回予測するのはこれ!

ocean_proximity:海から近いかどうか(4カテゴリ)

2.1 EDA

#必要なライブラリのインポート
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import matplotlib.colors as mcolors
import statsmodels.api as sm
import numpy as np
import math
import matplotlib.pyplot as plt
#データの読み込み
df = pd.read_csv("/content/drive/MyDrive/housing_price/housing.csv")
# 築年数の可視化
plt.hist(df["housing_median_age"], bins=20, color='skyblue')  # ヒストグラムを作成、20本のビンを使用していますが、必要に応じて調整できます
plt.xlabel('Housing Median Age')  # x軸のラベル
plt.ylabel('Frequency')  # y軸のラベル
plt.title('Distribution of Housing Median Age')  # グラフのタイトル
plt.grid(axis='y')  # y軸にのみグリッド線を表示
plt.show()  # グラフを表示

スクリーンショット 2023-12-08 0.02.52.png

#  住宅価格の分布
plt.figure(figsize=(8, 6))
plt.hist(df["median_house_value"], bins=30, color='skyblue')
plt.xlabel('Median House Value')
plt.ylabel('Frequency')
plt.title('Distribution of Median House Value')
plt.grid(axis='y')
plt.show()

スクリーンショット 2023-12-08 0.03.55.png

# 経度緯度を用いて住宅の価格を可視化
plt.scatter(df["longitude"],df["latitude"],c=df["median_house_value"], cmap='viridis')
# plt.scatter(city_center_x,city_center_y,marker='*',s=800,c="red")
#plt.plot(x, poly_equation(x), color='red', label='近似直線')
plt.colorbar(label='Median House Value')  # カラーバーのラベルを設定
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.title('Median House Value by Location')
plt.show()

左上に住宅価格が高い地域ががあります。
海に隣接している州のため、基本的に海辺は高いですね
スクリーンショット 2023-12-08 0.05.34.png

# 海に近いかどうかのカテゴリを使って可視化
sns.scatterplot(x='longitude', y='latitude', hue='ocean_proximity', data=df, palette='tab10')
plt.title('Ocean Proximity Scatter Plot')
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.legend(title='Ocean Proximity', bbox_to_anchor=(1.05, 1), loc='upper left')  # 凡例を図の外側に移動

plt.show()

スクリーンショット 2023-12-08 0.08.33.png

基本的に海辺に近ければ近いほど住宅は高い傾向にはあるが、
一部中心市街地は住宅価格が高い
将来的にも海辺の価格は低くなることはないと予想

特徴量の作成

散布図より 中心市街地の緯度が-121 経度が38であることがわかるため、その値を用いて その中で住宅価格が最大の地域を探す

high_median_house_value_ISLAND = df[(df["longitude"] >=-121) & (df["latitude"] >=38)]["median_house_value"].max()

land_df = df[(df["longitude"] >=-121) & (df["latitude"] >=38)]

city_center_x = land_df[land_df["median_house_value"] ==high_median_house_value_ISLAND]["longitude"].values
city_center_y = land_df[land_df["median_house_value"] ==high_median_house_value_ISLAND]["latitude"].values

これを使って、下の図を見てください。

スクリーンショット 2023-12-08 0.19.13.png

この赤線に関しては、海辺に近いカテゴリのみのデータを使って、OLSを使って一次関数を求めました。
下のコードがそれになります。

import matplotlib.pyplot as plt
import numpy as np

# データの準備
x = df[df["ocean_proximity"] =="NEAR OCEAN"]["longitude"].values
y = df[df["ocean_proximity"] =="NEAR OCEAN"]["latitude"].values

# 最小二乗法を使って近似直線の係数を求める
degree = 1  # 一次式を意味する
coefficients = np.polyfit(x, y, degree)

# 近似直線の式を表示
poly_equation = np.poly1d(coefficients)
print("近似直線の式:", poly_equation)

近似直線を求め、そこから点と線の距離を使って、近似直線までの距離を算出します。
また、中心市街地に近ければ近いほど住宅価格が高くなるので、
中心市街地を円の中心として、そこから半径の距離を算出し、そちらも説明変数とします。

下が一次直線からの距離の算出と中心市街地からの距離を算出し、説明変数を取得したものです。

#元々のデータから扱うデータのみをx_featuresに入れる
x_features = df[["housing_median_age","median_income"]].values

# 傾きと切片を取得
m = poly_equation[1]  # 1次の係数が傾き
c = poly_equation[0]  # 定数項が切片

# 新しい列を初期化
num_rows = len(df)
distance_ocean_column = np.zeros((num_rows, 1))  # distance_oceanの列
distance_city_center_column = np.zeros((num_rows, 1))  # distance_city_centerの列


for i,(longitude , latitude) in enumerate(df[["longitude","latitude"]].values):
  # 垂線の傾きを計算
  perpendicular_slope = -1 / m

  # 垂線の切片を計算
  perpendicular_intercept = latitude - perpendicular_slope * longitude

  # 垂線と直線の交点を計算して x0, y0 を求める
  x0 = (perpendicular_intercept - c) / (m - perpendicular_slope)
  y0 = m * x0 + c

  # 与えられた点と交点の距離を計算
  distance_ocean = np.sqrt((longitude - x0)**2 + (latitude - y0)**2)

  # 中心市街地からの距離を求める
  distance_city_center = math.sqrt((city_center_x - longitude)**2 + (city_center_y - latitude)**2)

  # distance_oceanとdistance_city_centerをそれぞれ列に追加
  distance_ocean_column[i] = distance_ocean
  distance_city_center_column[i] = distance_city_center

# x_featuresに新しい列を追加
x_features = np.concatenate((x_features, distance_ocean_column, distance_city_center_column), axis=1)


# 海に近い場所の4つ目の特徴量は0にする
for i in df[df['ocean_proximity'] != '<1H OCEAN'].index:
  x_features[i,3] = 0

予測モデルの作成

from sklearn.linear_model import Lasso
from sklearn.datasets import make_regression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error

# データをトレーニングセットとテストセットに分割
X_train, X_test, y_train, y_test = train_test_split(x_features, y, test_size=0.2, random_state=42)

# Lasso回帰モデルのインスタンス化
lasso = Lasso(alpha=0.2)  # alphaは正則化パラメータで調整可能

# モデルのトレーニング
lasso.fit(X_train, y_train)


# 係数の取得
coefficients = lasso.coef_

# 各特徴量の係数を表示
for i, coef in enumerate(coefficients):
    print(f"特徴量 {i + 1}: {coef}")

特徴量 1: 941.6964339359415
特徴量 2: 37997.160731779964
特徴量 3: -62248.12034008109
特徴量 4: 459.1748923825705

# テストセットでの予測
predictions = lasso.predict(X_test)

# モデルの評価
score = lasso.score(X_test, y_test)
print(f"モデルのスコア: {score}")

モデルのスコア: 0.5723663677004089

3.MLOpsで使う統計検定手法(コルモゴロフスミノロフ検定)

例えば市街地の発展によって、新たに商業施設ができたとする

その場合、1H OCEANの住宅価格が大きく変わる

急に500000以上の値が増えたらどうなるのか?

今回、1つの中心市街地としてそこからの距離を特徴量にした。

それが新たに1つ増えることになることも考慮しなくてはいけない。

もし考慮しなかった場合、本来住宅価格が高くなっている場所に誤った予測値を出すことになってしまう。

日々得られるデータから住宅価格の推移を得られるのであれば
それを日々チェックして、継続できるMLモデルを作成することが可能である。

#擬似データの作成
df_dammy = df.copy()
mask = ((df['longitude'] >= -121.74) & (df['longitude'] <= -120.74) &
        (df['latitude'] >= 38.87) & (df['latitude'] <= 39.87))
df_dammy.loc[mask, 'median_house_value'] = 500000

# データの可視化
plt.scatter(df_dammy["longitude"],df_dammy["latitude"],c=df_dammy["median_house_value"], cmap='viridis')
plt.scatter(city_center_x,city_center_y,marker='*',s=800,c="red")
plt.plot(df["longitude"], poly_equation(df["longitude"]), color='red', label='近似直線')
plt.colorbar(label='Median House Value')  # カラーバーのラベルを設定
plt.xlabel('Longitude')
plt.ylabel('Latitude')
plt.title('Median House Value by Location')
plt.show()

このように、中心市街地の隣接の町に何か複合施設ができたのでしょうか(すっとぼけ)
住宅価格が一部急上昇してます。
スクリーンショット 2023-12-08 0.32.38.png

モデルで使ったデータと新しく届いたデータを比較してみます

x = df_dammy["median_house_value"].values
y = df["median_house_value"].values

# サンプルをソートする
x = np.sort(x)
y = np.sort(y)

# ヒストグラムをプロットする
plt.hist(x, alpha=0.5, label='BaseLine')  # alphaは透明度を表す
plt.hist(y, alpha=0.5, label='Newdata')  # alphaを指定することで重ねて表示できる
plt.legend()  # 凡例を表示
plt.xlabel('Value')
plt.ylabel('Frequency')
plt.title('Histogram of Baseline and Newdata')
plt.show()

平均の価格が少し少なくなり、住宅価格が高いものが増えています
スクリーンショット 2023-12-08 0.35.31.png

以前までのデータと新しく届いたデータを比較したい時に用いるのが

コルモゴロフ–スミルノフ検定になります。

コルモゴロフ-スミルノフ検定とは

コルモゴロフ–スミルノフ検定は統計学における仮説検定の一種であり、有限個の標本に基づいて、二つの母集団の確率分布が異なるものであるかどうか、あるいは母集団の確率分布が帰無仮説で提示された分布と異なっているかどうかを調べるために用いられます。

これはscipyで簡単に実装することが可能です。

仮に、有意水準を5%と設定します。

from scipy.stats import kstest
result_pvalue = kstest(x,y).pvalue
Interested_parties = 95
base_pvalue = (100-Interested_parties)/100
print(kstest(x,y))
# KstestResult(statistic=0.016521317829457427, pvalue=0.007070262373769401, statistic_location=361900.0, statistic_sign=-1)

if base_pvalue > result_pvalue:
  print("帰無仮説が棄却されます")
else:
  print("正常値です")

この場合、p値が0.0071のため、帰無仮説が棄却されます。
そのため、適切にモデルが動かなくなる恐れがあります。

とは言っても、RMSEなどを見てモデルを再学習させた方が適切かもしれません。

この分布の違いくらいであれば、Slackなどに通知したりするのがベストかもしれませんね。

まとめ

せっかく試行錯誤して作ったモデルがうまく予測できなくなったら嫌ですよね
継続的にモデルを活用すための工夫をこれ以外にも色々と考えていきたいところです。

別話(こんなものもあるよ)

日頃の業務もあり、これ以上進めることはできなかったですが

本当はAWSのSagemakerでモデルをデプロイして
lambdaでscipyを使って、条件分岐でモデル再デプロイみたいなことをしたかったです。

資格勉強が最優先なので、それが終わったら取り掛かろうかと思います。。

また、AWSにはDeeqというオープンソースライブラリがあり
データ品質を確認することができるライブラリが存在します。中身はApache Sparkで動いているとか

6
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
6
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?