Rental Bike Demand Prediction

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import mean_squared_error, r2_score
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import lightgbm as lgb
import logging

Show only errors in LightGBM's log output

logging.getLogger('lightgbm').setLevel(logging.ERROR)

Load the data

train_data = pd.read_csv('hour_train.csv')
test_data = pd.read_csv('hour_test.csv')

Create new features

train_data['hr_8_17_18'] = train_data['hr'].apply(lambda x: 1 if x in [8, 17, 18] else 0)
test_data['hr_8_17_18'] = test_data['hr'].apply(lambda x: 1 if x in [8, 17, 18] else 0)

train_data['atemp_0.2_0.8'] = train_data['atemp'].apply(lambda x: 1 if 0.2 <= x <= 0.8 else 0)
test_data['atemp_0.2_0.8'] = test_data['atemp'].apply(lambda x: 1 if 0.2 <= x <= 0.8 else 0)

train_data['temp_0.2_0.8'] = train_data['temp'].apply(lambda x: 1 if 0.2 <= x <= 0.8 else 0)
test_data['temp_0.2_0.8'] = test_data['temp'].apply(lambda x: 1 if 0.2 <= x <= 0.8 else 0)

train_data['hum_above_0.2'] = train_data['hum'].apply(lambda x: 1 if x >= 0.2 else 0)
test_data['hum_above_0.2'] = test_data['hum'].apply(lambda x: 1 if x >= 0.2 else 0)

train_data['windspeed_below_0.6'] = train_data['windspeed'].apply(lambda x: 1 if x <= 0.6 else 0)
test_data['windspeed_below_0.6'] = test_data['windspeed'].apply(lambda x: 1 if x <= 0.6 else 0)

train_data['hr_7_20'] = train_data['hr'].apply(lambda x: 1 if 7 <= x <= 20 else 0)
test_data['hr_7_20'] = test_data['hr'].apply(lambda x: 1 if 7 <= x <= 20 else 0)
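
The same 0/1 flags can also be built without a per-row lambda. The following is a minimal sketch, assuming a tiny made-up frame in place of `hour_train.csv`; it uses the pandas `Series.isin` and `Series.between` methods, and `between` includes both endpoints by default, matching the `<=` comparisons above.

```python
import pandas as pd

# Tiny made-up frame standing in for hour_train.csv (values are illustrative only)
df = pd.DataFrame({
    'hr': [7, 8, 17, 21],
    'atemp': [0.1, 0.2, 0.5, 0.9],
})

# Same flags as the lambda versions, computed on the whole column at once
df['hr_8_17_18'] = df['hr'].isin([8, 17, 18]).astype(int)
df['atemp_0.2_0.8'] = df['atemp'].between(0.2, 0.8).astype(int)  # inclusive on both ends
df['hr_7_20'] = df['hr'].between(7, 20).astype(int)

print(df)
```

Applying the same expressions to both `train_data` and `test_data` should give the same columns as the `apply` calls above.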

Convert 'season' and 'weathersit' to dummy variables

train_data = pd.get_dummies(train_data, columns=['season', 'weathersit'], drop_first=True)
test_data = pd.get_dummies(test_data, columns=['season', 'weathersit'], drop_first=True)

Align the dummy-variable columns between the training and test data

train_columns = set(train_data.columns)
test_columns = set(test_data.columns)
missing_in_test = train_columns - test_columns
missing_in_train = test_columns - train_columns

for col in missing_in_test:
    test_data[col] = 0
for col in missing_in_train:
    train_data[col] = 0

Sort the columns so both frames share the same order

train_data = train_data.sort_index(axis=1)
test_data = test_data.sort_index(axis=1)
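
A shorter way to line up the test columns is `DataFrame.reindex`. This is a sketch on made-up data, not the article's method: unlike the loops above, it treats the training columns as the reference, so a column that exists only in the test frame would be dropped rather than added to the training frame.

```python
import pandas as pd

# Made-up frames: the test set never sees weathersit == 4
train = pd.DataFrame({'weathersit': [1, 2, 3, 4], 'cnt': [10, 20, 30, 40]})
test = pd.DataFrame({'weathersit': [1, 2, 3], 'cnt': [11, 21, 31]})

train = pd.get_dummies(train, columns=['weathersit'], drop_first=True)
test = pd.get_dummies(test, columns=['weathersit'], drop_first=True)

# Give the test frame exactly the training columns, in the same order;
# columns missing from the test frame are created and filled with 0
test = test.reindex(columns=train.columns, fill_value=0)

print(test)
```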

Set the features and the target variable

features = (
    ['temp', 'atemp', 'yr', 'mnth', 'hr', 'hum', 'holiday', 'workingday',
     'hr_8_17_18', 'atemp_0.2_0.8', 'temp_0.2_0.8', 'hum_above_0.2', 'hr_7_20']
    + [col for col in train_data.columns if col.startswith('season_') or col.startswith('weathersit_')]
)
target = 'cnt'

X_train = train_data[features]
y_train = train_data[target]

Standardize the data

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)

Create the LightGBM model

model = lgb.LGBMRegressor(verbose=-1, random_state=42)

Evaluate the model with cross-validation

cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=5, scoring='neg_mean_squared_error')
mean_cv_mse = -cv_scores.mean()
mean_cv_r2 = cross_val_score(model, X_train_scaled, y_train, cv=5, scoring='r2').mean()

print(f'Cross-Validation Mean Squared Error: {mean_cv_mse:.2f}')
print(f'Cross-Validation R^2 Score: {mean_cv_r2:.2f}')

Split the data (training and validation)

X_train_split, X_val_split, y_train_split, y_val_split = train_test_split(X_train_scaled, y_train, test_size=0.2, random_state=42)

Train the model

model.fit(X_train_split, y_train_split)

Predict on the validation data

y_val_pred = model.predict(X_val_split)

Evaluate on the validation data

print(f'Validation Mean Squared Error: {mean_squared_error(y_val_split, y_val_pred):.2f}')
print(f'Validation R^2 Score: {r2_score(y_val_split, y_val_pred):.2f}')

Predict on the test data

X_test = test_data[features]
X_test_scaled = scaler.transform(X_test)
y_test_pred = model.predict(X_test_scaled)

Actual values for the test data

y_test_actual = test_data[target]

Evaluate on the test data

test_mse = mean_squared_error(y_test_actual, y_test_pred)
test_r2 = r2_score(y_test_actual, y_test_pred)

print(f'Test Mean Squared Error: {test_mse:.2f}')
print(f'Test R^2 Score: {test_r2:.2f}')

Plot line charts of actual and predicted values over time

test_data['y_test_pred'] = y_test_pred

Compute the mean actual and predicted values for each hour of the day

hourly_avg_actual = test_data.groupby('hr')[target].mean()
hourly_avg_pred = test_data.groupby('hr')['y_test_pred'].mean()

Plot a line chart of predicted and actual values by hour of the day

plt.figure(figsize=(12, 6))
plt.plot(hourly_avg_actual.index, hourly_avg_actual, label='Actual', marker='o')
plt.plot(hourly_avg_pred.index, hourly_avg_pred, label='Predicted', marker='o')
plt.xlabel('Hour of the Day')
plt.ylabel('Average Count')
plt.title('Average Actual and Predicted Count by Hour of the Day')
plt.legend()
plt.grid(True)
plt.show()

Compute the mean actual and predicted values for each dteday

test_data['dteday'] = pd.to_datetime(test_data['dteday'])  # convert dteday to datetime
daily_avg_actual = test_data.groupby('dteday')[target].mean()
daily_avg_pred = test_data.groupby('dteday')['y_test_pred'].mean()
daily_workingday = test_data.groupby('dteday')['workingday'].first()

Plot a line chart of actual and predicted values for each dteday

plt.figure(figsize=(14, 7))
plt.plot(daily_avg_actual.index, daily_avg_actual, label='Actual', marker='o')
plt.plot(daily_avg_pred.index, daily_avg_pred, label='Predicted', marker='o')
plt.xlabel('Date')
plt.ylabel('Average Count')
plt.title('Average Actual and Predicted Count by Date')
plt.legend()
plt.grid(True)

Add the workingday information along the horizontal axis

for date, is_workingday in daily_workingday.items():
    plt.axvline(x=date, color='r' if is_workingday else 'b', linestyle='--', alpha=0.5)

plt.xticks(rotation=45)
plt.show()

Compute the mean difference between actual and predicted values for each hour

test_data['error'] = test_data[target] - test_data['y_test_pred']
hourly_avg_error = test_data.groupby('hr')['error'].mean()

Compute the mean difference between actual and predicted values for each day

daily_avg_error = test_data.groupby('dteday')['error'].mean()

Display the mean differences

print("Hourly Average Error:")
print(hourly_avg_error)

print("\nDaily Average Error:")
print(daily_avg_error)
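
The mean of a signed error shows bias, but over- and under-predictions within a group cancel each other out. A small sketch on made-up numbers shows the mean absolute error next to it; the `abs_error` column is a name introduced here for illustration and is not part of the script above.

```python
import pandas as pd

# Made-up predictions: each hour has one over- and one under-prediction
df = pd.DataFrame({
    'hr': [8, 8, 17, 17],
    'cnt': [100, 200, 300, 400],
    'y_test_pred': [120, 180, 310, 390],
})

df['error'] = df['cnt'] - df['y_test_pred']  # signed error, as in the script
df['abs_error'] = df['error'].abs()          # magnitude only (hypothetical extra column)

# Mean signed error vs. mean absolute error for each hour
summary = df.groupby('hr')[['error', 'abs_error']].mean()
print(summary)
```

In this toy case the signed mean is zero for both hours even though every prediction is off, which is why the absolute error can be worth printing alongside it.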
