Data Every Day: Walmart Retail Analysis

TL;DR

This post follows Walmart Holiday Sale Prediction - Data Every Day #082, using the Kaggle dataset Walmart Data-Retail Analysis.

The execution environment is Google Colaboratory.

Imports

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import sklearn.preprocessing as sp
from sklearn.model_selection import train_test_split
import sklearn.linear_model as slm
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier

import tensorflow as tf

Downloading the data

Mount Google Drive.

from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

Initialize the Kaggle API client and authenticate.
The credentials are stored as kaggle.json in Google Drive (/content/drive/My Drive/Colab Notebooks/Kaggle).

import os
kaggle_path = "/content/drive/My Drive/Colab Notebooks/Kaggle"
os.environ['KAGGLE_CONFIG_DIR'] = kaggle_path

from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate() 
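
Authentication depends on kaggle.json being present in the directory that KAGGLE_CONFIG_DIR points to, so it can help to check for the file first. The following is a minimal sketch; the helper name `has_kaggle_credentials` is hypothetical and not part of the Kaggle client.

```python
import os

def has_kaggle_credentials(config_dir):
    """Return True if kaggle.json exists in config_dir (hypothetical helper)."""
    return os.path.isfile(os.path.join(config_dir, "kaggle.json"))

# Same directory as used in this post
kaggle_path = "/content/drive/My Drive/Colab Notebooks/Kaggle"
if not has_kaggle_credentials(kaggle_path):
    print("kaggle.json not found in", kaggle_path)
```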

Download the data with the Kaggle API.

dataset_id = 'vik2012kvs/walmart-dataretail-analysis'
dataset = api.dataset_list_files(dataset_id)
file_name = dataset.files[0].name
file_path = os.path.join(api.get_default_download_dir(), file_name)
file_path
'/content/Walmart_Store_sales.csv'
api.dataset_download_file(dataset_id, file_name, force=True, quiet=False)
100%|██████████| 355k/355k [00:00<00:00, 40.9MB/s]
Downloading Walmart_Store_sales.csv to /content
True

Loading the data

Read the downloaded CSV file with pandas.

data = pd.read_csv(file_path)
data
      Store        Date  Weekly_Sales  Holiday_Flag  Temperature  Fuel_Price         CPI  Unemployment
0         1  05-02-2010    1643690.90             0        42.31       2.572  211.096358         8.106
1         1  12-02-2010    1641957.44             1        38.51       2.548  211.242170         8.106
2         1  19-02-2010    1611968.17             0        39.93       2.514  211.289143         8.106
3         1  26-02-2010    1409727.59             0        46.63       2.561  211.319643         8.106
4         1  05-03-2010    1554806.68             0        46.50       2.625  211.350143         8.106
...     ...         ...           ...           ...          ...         ...         ...           ...
6430     45  28-09-2012     713173.95             0        64.88       3.997  192.013558         8.684
6431     45  05-10-2012     733455.07             0        64.89       3.985  192.170412         8.667
6432     45  12-10-2012     734464.36             0        54.47       4.000  192.327265         8.667
6433     45  19-10-2012     718125.53             0        56.47       3.969  192.330854         8.667
6434     45  26-10-2012     760281.43             0        58.85       3.882  192.308899         8.667

6435 rows × 8 columns

Preprocessing

data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6435 entries, 0 to 6434
Data columns (total 8 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   Store         6435 non-null   int64  
 1   Date          6435 non-null   object 
 2   Weekly_Sales  6435 non-null   float64
 3   Holiday_Flag  6435 non-null   int64  
 4   Temperature   6435 non-null   float64
 5   Fuel_Price    6435 non-null   float64
 6   CPI           6435 non-null   float64
 7   Unemployment  6435 non-null   float64
dtypes: float64(5), int64(2), object(1)
memory usage: 402.3+ KB
# Dates are DD-MM-YYYY strings; extract year and month as integers
data['Year'] = data['Date'].str[-4:].astype(int)
data['Month'] = data['Date'].str[3:5].astype(int)

data = data.drop('Date', axis=1)
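
The year and month can also be obtained by parsing the column as dates rather than slicing strings. This is a sketch on a few toy rows in the same DD-MM-YYYY format; the `sample` frame is an illustration, not the real dataset.

```python
import pandas as pd

# Toy rows in the same DD-MM-YYYY format as the Date column
sample = pd.DataFrame({"Date": ["05-02-2010", "12-02-2010", "26-10-2012"]})

# An explicit format avoids day/month ambiguity (05-02 is 5 February here)
parsed = pd.to_datetime(sample["Date"], format="%d-%m-%Y")
sample["Year"] = parsed.dt.year
sample["Month"] = parsed.dt.month
print(sample)
```

Parsing with an explicit format also raises an error on malformed dates, which plain string slicing would silently pass through.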

Visualization

corr = data.corr()
plt.figure(figsize=(12, 10))
sns.heatmap(corr, annot=True, vmin=-1.0, vmax=1.0)
<matplotlib.axes._subplots.AxesSubplot at 0x7f429d8ca518>

[Figure: heatmap of the correlation matrix]

Encoding

def onehot_encode(df, column, prefix):
    # Replace `column` with one indicator column per distinct value
    df = df.copy()
    dummies = pd.get_dummies(df[column], prefix=prefix)
    df = pd.concat([df, dummies], axis=1)
    df = df.drop(column, axis=1)
    return df

data = onehot_encode(data, column='Store', prefix='store')
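
To see what the encoding produces, here is a small sketch on a toy frame. The function is repeated so the block runs on its own, and the `toy` data is made up for illustration.

```python
import pandas as pd

def onehot_encode(df, column, prefix):
    # Replace `column` with one indicator column per distinct value
    df = df.copy()
    dummies = pd.get_dummies(df[column], prefix=prefix)
    df = pd.concat([df, dummies], axis=1)
    return df.drop(column, axis=1)

# Toy data: two distinct stores
toy = pd.DataFrame({"Store": [1, 2, 1], "Weekly_Sales": [10.0, 20.0, 30.0]})
encoded = onehot_encode(toy, column="Store", prefix="store")
print(encoded.columns.tolist())  # ['Weekly_Sales', 'store_1', 'store_2']
```

On the real data, the Store column has store numbers up to 45, so this step adds one indicator column per store.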

Splitting and scaling the data

y = data['Holiday_Flag']
X = data.drop(['Holiday_Flag'], axis=1)
scaler = sp.StandardScaler()
X = pd.DataFrame(scaler.fit_transform(X), index=X.index, columns=X.columns)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.7)
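
The code above fits the scaler on all rows before splitting, so the test rows influence the scaling statistics. A common alternative is to split first and fit the scaler on the training rows only. The sketch below shows that variant on synthetic data. The `X_demo` and `y_demo` names, the `stratify` argument, and the fixed `random_state` are assumptions added for illustration and are not part of the original notebook.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in data: 200 rows, 10% positive labels (illustrative only)
rng = np.random.default_rng(0)
X_demo = pd.DataFrame(rng.normal(size=(200, 3)), columns=["a", "b", "c"])
y_demo = pd.Series([1] * 20 + [0] * 180)

# Split first; stratify keeps the rare class in both parts
X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, train_size=0.7, stratify=y_demo, random_state=0
)

# Fit the scaler on the training rows only, then apply it to both parts
scaler = StandardScaler()
X_tr_scaled = pd.DataFrame(scaler.fit_transform(X_tr), index=X_tr.index, columns=X_tr.columns)
X_te_scaled = pd.DataFrame(scaler.transform(X_te), index=X_te.index, columns=X_te.columns)
```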

Training

log_model = slm.LogisticRegression()
svm_model = SVC()
dec_model = DecisionTreeClassifier()
log_model.fit(X_train, y_train)
svm_model.fit(X_train, y_train)
dec_model.fit(X_train, y_train)
DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',
                       max_depth=None, max_features=None, max_leaf_nodes=None,
                       min_impurity_decrease=0.0, min_impurity_split=None,
                       min_samples_leaf=1, min_samples_split=2,
                       min_weight_fraction_leaf=0.0, presort='deprecated',
                       random_state=None, splitter='best')

Results

print("Logistic Regression Accuracy:", log_model.score(X_test, y_test))
print("                SVM Accuracy:", svm_model.score(X_test, y_test))
print("      Decision Tree Accuracy:", dec_model.score(X_test, y_test))
Logistic Regression Accuracy: 0.9337131020196789
                SVM Accuracy: 0.9337131020196789
      Decision Tree Accuracy: 0.9528741584671155
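
Because holiday weeks are rare in this dataset, accuracy can look high even for a model that always predicts 0. The identical Logistic Regression and SVM scores may reflect this, although the original notebook does not check it. A minimal sketch of a majority-class baseline follows; the 93/7 toy labels are an assumption chosen for illustration, not the real class ratio.

```python
import numpy as np
from sklearn.dummy import DummyClassifier

# Toy labels: 93 non-holiday rows and 7 holiday rows (illustrative only)
y_demo = np.array([0] * 93 + [1] * 7)
X_demo = np.zeros((len(y_demo), 1))  # features are ignored by this baseline

# Always predicts the most frequent class seen during fit
baseline = DummyClassifier(strategy="most_frequent")
baseline.fit(X_demo, y_demo)
baseline_acc = baseline.score(X_demo, y_demo)
print("Majority-class baseline accuracy:", baseline_acc)
```

On the real split, fitting the same baseline on X_train, y_train and scoring it on X_test, y_test would give a reference point for the three accuracies above.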
