tldr
KggleのWalmart Data-Retail AnalysisをWalmart Holiday Sale Prediction - Data Every Day #082に沿ってやっていきます。
実行環境はGoogle Colaboratorです。
インポート
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn.preprocessing as sp
from sklearn.model_selection import train_test_split
import sklearn.linear_model as slm
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
import tensorflow as tf
データのダウンロード
Google Driveをマウントします。
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
KaggleのAPIクライアントを初期化し、認証します。
認証情報はGoogle Drive内(/content/drive/My Drive/Colab Notebooks/Kaggle
)にkaggle.json
として置いてあります。
import os
kaggle_path = "/content/drive/My Drive/Colab Notebooks/Kaggle"
os.environ['KAGGLE_CONFIG_DIR'] = kaggle_path
from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate()
Kaggle APIを使ってデータをダウンロードします。
dataset_id = 'vik2012kvs/walmart-dataretail-analysis'
dataset = api.dataset_list_files(dataset_id)
file_name = dataset.files[0].name
file_path = os.path.join(api.get_default_download_dir(), file_name)
file_path
'/content/Walmart_Store_sales.csv'
api.dataset_download_file(dataset_id, file_name, force=True, quiet=False)
100%|██████████| 355k/355k [00:00<00:00, 40.9MB/s]
Downloading Walmart_Store_sales.csv to /content
True
データの読み込み
Pedumagalhaes/quality-prediction-in-a-mining-processadasを使ってダウンロードしてきたCSVファイルを読み込みます。
data = pd.read_csv(file_path)
data
Store | Date | Weekly_Sales | Holiday_Flag | Temperature | Fuel_Price | CPI | Unemployment | |
---|---|---|---|---|---|---|---|---|
0 | 1 | 05-02-2010 | 1643690.90 | 0 | 42.31 | 2.572 | 211.096358 | 8.106 |
1 | 1 | 12-02-2010 | 1641957.44 | 1 | 38.51 | 2.548 | 211.242170 | 8.106 |
2 | 1 | 19-02-2010 | 1611968.17 | 0 | 39.93 | 2.514 | 211.289143 | 8.106 |
3 | 1 | 26-02-2010 | 1409727.59 | 0 | 46.63 | 2.561 | 211.319643 | 8.106 |
4 | 1 | 05-03-2010 | 1554806.68 | 0 | 46.50 | 2.625 | 211.350143 | 8.106 |
... | ... | ... | ... | ... | ... | ... | ... | ... |
6430 | 45 | 28-09-2012 | 713173.95 | 0 | 64.88 | 3.997 | 192.013558 | 8.684 |
6431 | 45 | 05-10-2012 | 733455.07 | 0 | 64.89 | 3.985 | 192.170412 | 8.667 |
6432 | 45 | 12-10-2012 | 734464.36 | 0 | 54.47 | 4.000 | 192.327265 | 8.667 |
6433 | 45 | 19-10-2012 | 718125.53 | 0 | 56.47 | 3.969 | 192.330854 | 8.667 |
6434 | 45 | 26-10-2012 | 760281.43 | 0 | 58.85 | 3.882 | 192.308899 | 8.667 |
6435 rows × 8 columns
下準備
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6435 entries, 0 to 6434
Data columns (total 8 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Store 6435 non-null int64
1 Date 6435 non-null object
2 Weekly_Sales 6435 non-null float64
3 Holiday_Flag 6435 non-null int64
4 Temperature 6435 non-null float64
5 Fuel_Price 6435 non-null float64
6 CPI 6435 non-null float64
7 Unemployment 6435 non-null float64
dtypes: float64(5), int64(2), object(1)
memory usage: 402.3+ KB
data['Year'] = data['Date'].apply(lambda x: x[-4:])
data['Month'] = data['Date'].apply(lambda x: x[3:5])
data = data.drop('Date', axis=1)
可視化
corr = data.corr()
plt.figure(figsize=(12, 10))
sns.heatmap(corr, annot=True, vmin=-1.0, vmax=1.0)
<matplotlib.axes._subplots.AxesSubplot at 0x7f429d8ca518>
エンコード
def onehot_encode(df, column, prefix):
df = df.copy()
dummies = pd.get_dummies(df[column], prefix)
df = pd.concat([df, dummies], axis=1)
df = df.drop(column, axis=1)
return df
data = onehot_encode(data, column='Store', prefix='store')
データの分割とスケーリング
y = data['Holiday_Flag']
X = data.drop(['Holiday_Flag'], axis=1)
scaler = sp.StandardScaler()
X = pd.DataFrame(scaler.fit_transform(X), index=X.index, columns=X.columns)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7)
トレーニング
log_model = slm.LogisticRegression()
svm_model = SVC()
dec_model = DecisionTreeClassifier()
log_model.fit(X_train, y_train)
svm_model.fit(X_train, y_train)
dec_model.fit(X_train, y_train)
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
max_depth=None, max_features=None, max_leaf_nodes=None,
min_impurity_decrease=0.0, min_impurity_split=None,
min_samples_leaf=1, min_samples_split=2,
min_weight_fraction_leaf=0.0, presort='deprecated',
random_state=None, splitter='best')
結果
print("Logistic Regression Accuracy:", log_model.score(X_test, y_test))
print(" SVM Accuracy:", svm_model.score(X_test, y_test))
print(" Decision Tree Accuracy:", dec_model.score(X_test, y_test))
Logistic Regression Accuracy: 0.9337131020196789
SVM Accuracy: 0.9337131020196789
Decision Tree Accuracy: 0.9528741584671155