tldr
KaggleのQuality Prediction in a Mining ProcessをMining Quality Prediction - Data Every Day #039に沿ってやっていきます。
実行環境はGoogle Colaboratoryです。
インポート
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn.preprocessing as sp
from sklearn.model_selection import train_test_split
import sklearn.linear_model as slm
import tensorflow as tf
データのダウンロード
Google Driveをマウントします。
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
KaggleのAPIクライアントを初期化し、認証します。
認証情報はGoogle Drive内(/content/drive/My Drive/Colab Notebooks/Kaggle)にkaggle.jsonとして置いてあります。
import os
# Folder on the mounted Drive that holds kaggle.json.
kaggle_path = "/content/drive/My Drive/Colab Notebooks/Kaggle"
# Point the Kaggle client at that folder for its credentials.
# NOTE(review): this is set before the kaggle import below — presumably the
# package reads its configuration at import time; keep this order unless
# confirmed otherwise.
os.environ['KAGGLE_CONFIG_DIR'] = kaggle_path
from kaggle.api.kaggle_api_extended import KaggleApi
# Create the API client and authenticate it.
api = KaggleApi()
api.authenticate()
Kaggle APIを使ってデータをダウンロードします。
# Kaggle dataset identifier.
dataset_id = 'edumagalhaes/quality-prediction-in-a-mining-process'

# Ask Kaggle which files the dataset contains and pick the first one.
dataset = api.dataset_list_files(dataset_id)
first_file = dataset.files[0]
file_name = first_file.name

# Build the local path from the API's default download directory.
download_dir = api.get_default_download_dir()
file_path = os.path.join(download_dir, file_name)
file_path
'/content/MiningProcess_Flotation_Plant_Database.csv'
# Download the selected file, passing force=True and quiet=False explicitly.
# The log below shows it arriving in /content as '<file_name>.zip'.
api.dataset_download_file(dataset_id, file_name, force=True, quiet=False)
10%|▉ | 5.00M/50.9M [00:00<00:00, 50.0MB/s]
Downloading MiningProcess_Flotation_Plant_Database.csv.zip to /content
100%|██████████| 50.9M/50.9M [00:00<00:00, 87.3MB/s]
True
データの読み込み
Pandasを使ってダウンロードしてきたCSVファイルを読み込みます。
# The download log shows the file was saved with a '.zip' suffix, so read that
# archive directly; pandas infers the compression from the extension.
zip_path = f'{file_path}.zip'
data = pd.read_csv(zip_path)
data
date | % Iron Feed | % Silica Feed | Starch Flow | Amina Flow | Ore Pulp Flow | Ore Pulp pH | Ore Pulp Density | Flotation Column 01 Air Flow | Flotation Column 02 Air Flow | Flotation Column 03 Air Flow | Flotation Column 04 Air Flow | Flotation Column 05 Air Flow | Flotation Column 06 Air Flow | Flotation Column 07 Air Flow | Flotation Column 01 Level | Flotation Column 02 Level | Flotation Column 03 Level | Flotation Column 04 Level | Flotation Column 05 Level | Flotation Column 06 Level | Flotation Column 07 Level | % Iron Concentrate | % Silica Concentrate | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 2017-03-10 01:00:00 | 55,2 | 16,98 | 3019,53 | 557,434 | 395,713 | 10,0664 | 1,74 | 249,214 | 253,235 | 250,576 | 295,096 | 306,4 | 250,225 | 250,884 | 457,396 | 432,962 | 424,954 | 443,558 | 502,255 | 446,37 | 523,344 | 66,91 | 1,31 |
1 | 2017-03-10 01:00:00 | 55,2 | 16,98 | 3024,41 | 563,965 | 397,383 | 10,0672 | 1,74 | 249,719 | 250,532 | 250,862 | 295,096 | 306,4 | 250,137 | 248,994 | 451,891 | 429,56 | 432,939 | 448,086 | 496,363 | 445,922 | 498,075 | 66,91 | 1,31 |
2 | 2017-03-10 01:00:00 | 55,2 | 16,98 | 3043,46 | 568,054 | 399,668 | 10,068 | 1,74 | 249,741 | 247,874 | 250,313 | 295,096 | 306,4 | 251,345 | 248,071 | 451,24 | 468,927 | 434,61 | 449,688 | 484,411 | 447,826 | 458,567 | 66,91 | 1,31 |
3 | 2017-03-10 01:00:00 | 55,2 | 16,98 | 3047,36 | 568,665 | 397,939 | 10,0689 | 1,74 | 249,917 | 254,487 | 250,049 | 295,096 | 306,4 | 250,422 | 251,147 | 452,441 | 458,165 | 442,865 | 446,21 | 471,411 | 437,69 | 427,669 | 66,91 | 1,31 |
4 | 2017-03-10 01:00:00 | 55,2 | 16,98 | 3033,69 | 558,167 | 400,254 | 10,0697 | 1,74 | 250,203 | 252,136 | 249,895 | 295,096 | 306,4 | 249,983 | 248,928 | 452,441 | 452,9 | 450,523 | 453,67 | 462,598 | 443,682 | 425,679 | 66,91 | 1,31 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
737448 | 2017-09-09 23:00:00 | 49,75 | 23,2 | 2710,94 | 441,052 | 386,57 | 9,62129 | 1,65365 | 302,344 | 298,786 | 299,163 | 299,92 | 299,623 | 346,794 | 313,695 | 392,16 | 430,702 | 872,008 | 418,725 | 497,548 | 446,357 | 416,892 | 64,27 | 1,71 |
737449 | 2017-09-09 23:00:00 | 49,75 | 23,2 | 2692,01 | 473,436 | 384,939 | 9,62063 | 1,65352 | 303,013 | 301,879 | 299,487 | 299,71 | 300,465 | 330,023 | 236,7 | 401,505 | 404,616 | 864,409 | 418,377 | 506,398 | 372,995 | 426,337 | 64,27 | 1,71 |
737450 | 2017-09-09 23:00:00 | 49,75 | 23,2 | 2692,2 | 500,488 | 383,496 | 9,61874 | 1,65338 | 303,662 | 307,397 | 299,487 | 299,927 | 299,707 | 329,59 | 225,879 | 408,899 | 399,316 | 867,598 | 419,531 | 503,414 | 336,035 | 433,13 | 64,27 | 1,71 |
737451 | 2017-09-09 23:00:00 | 49,75 | 23,2 | 1164,12 | 491,548 | 384,976 | 9,61686 | 1,65324 | 302,55 | 301,959 | 298,045 | 299,372 | 298,819 | 351,453 | 308,115 | 405,107 | 466,832 | 876,591 | 407,299 | 502,301 | 340,844 | 433,966 | 64,27 | 1,71 |
737452 | 2017-09-09 23:00:00 | 49,75 | 23,2 | 1164,12 | 468,019 | 384,801 | 9,61497 | 1,6531 | 300,355 | 292,865 | 298,625 | 298,717 | 297,395 | 362,464 | 308,115 | 413,754 | 514,143 | 881,323 | 378,969 | 500,1 | 374,354 | 441,182 | 64,27 | 1,71 |
737453 rows × 24 columns
下準備
# The CSV writes numbers with a decimal comma (e.g. '55,2'), so every column
# was loaded as text. Convert the measurement columns to real floats using the
# vectorised string accessor instead of a per-cell Python lambda. The 'date'
# column contains no commas and is left as text for the date-parsing step.
# A value that is not a valid number after the replacement raises ValueError.
for column in data.columns:
    if column == 'date':
        continue
    data[column] = data[column].str.replace(',', '.', regex=False).astype(float)
import re
# Timestamps look like '2017-03-10 01:00:00'. Pull the leading
# '<year>-<month>' apart in one vectorised pass instead of three row-by-row
# regex searches. Both parts stay strings ('2017', '03'), as before.
# A row whose date does not match yields NaN rather than raising.
year_month = data['date'].str.extract(r'([0-9]*)-([0-9]*)')
data['Year'] = year_month[0]
data['Month'] = year_month[1]

# The raw timestamp is no longer needed once year and month are split out.
data = data.drop('date', axis=1)
data
% Iron Feed | % Silica Feed | Starch Flow | Amina Flow | Ore Pulp Flow | Ore Pulp pH | Ore Pulp Density | Flotation Column 01 Air Flow | Flotation Column 02 Air Flow | Flotation Column 03 Air Flow | Flotation Column 04 Air Flow | Flotation Column 05 Air Flow | Flotation Column 06 Air Flow | Flotation Column 07 Air Flow | Flotation Column 01 Level | Flotation Column 02 Level | Flotation Column 03 Level | Flotation Column 04 Level | Flotation Column 05 Level | Flotation Column 06 Level | Flotation Column 07 Level | % Iron Concentrate | % Silica Concentrate | Year | Month | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 55.2 | 16.98 | 3019.53 | 557.434 | 395.713 | 10.0664 | 1.74 | 249.214 | 253.235 | 250.576 | 295.096 | 306.4 | 250.225 | 250.884 | 457.396 | 432.962 | 424.954 | 443.558 | 502.255 | 446.37 | 523.344 | 66.91 | 1.31 | 2017 | 03 |
1 | 55.2 | 16.98 | 3024.41 | 563.965 | 397.383 | 10.0672 | 1.74 | 249.719 | 250.532 | 250.862 | 295.096 | 306.4 | 250.137 | 248.994 | 451.891 | 429.56 | 432.939 | 448.086 | 496.363 | 445.922 | 498.075 | 66.91 | 1.31 | 2017 | 03 |
2 | 55.2 | 16.98 | 3043.46 | 568.054 | 399.668 | 10.068 | 1.74 | 249.741 | 247.874 | 250.313 | 295.096 | 306.4 | 251.345 | 248.071 | 451.24 | 468.927 | 434.61 | 449.688 | 484.411 | 447.826 | 458.567 | 66.91 | 1.31 | 2017 | 03 |
3 | 55.2 | 16.98 | 3047.36 | 568.665 | 397.939 | 10.0689 | 1.74 | 249.917 | 254.487 | 250.049 | 295.096 | 306.4 | 250.422 | 251.147 | 452.441 | 458.165 | 442.865 | 446.21 | 471.411 | 437.69 | 427.669 | 66.91 | 1.31 | 2017 | 03 |
4 | 55.2 | 16.98 | 3033.69 | 558.167 | 400.254 | 10.0697 | 1.74 | 250.203 | 252.136 | 249.895 | 295.096 | 306.4 | 249.983 | 248.928 | 452.441 | 452.9 | 450.523 | 453.67 | 462.598 | 443.682 | 425.679 | 66.91 | 1.31 | 2017 | 03 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
737448 | 49.75 | 23.2 | 2710.94 | 441.052 | 386.57 | 9.62129 | 1.65365 | 302.344 | 298.786 | 299.163 | 299.92 | 299.623 | 346.794 | 313.695 | 392.16 | 430.702 | 872.008 | 418.725 | 497.548 | 446.357 | 416.892 | 64.27 | 1.71 | 2017 | 09 |
737449 | 49.75 | 23.2 | 2692.01 | 473.436 | 384.939 | 9.62063 | 1.65352 | 303.013 | 301.879 | 299.487 | 299.71 | 300.465 | 330.023 | 236.7 | 401.505 | 404.616 | 864.409 | 418.377 | 506.398 | 372.995 | 426.337 | 64.27 | 1.71 | 2017 | 09 |
737450 | 49.75 | 23.2 | 2692.2 | 500.488 | 383.496 | 9.61874 | 1.65338 | 303.662 | 307.397 | 299.487 | 299.927 | 299.707 | 329.59 | 225.879 | 408.899 | 399.316 | 867.598 | 419.531 | 503.414 | 336.035 | 433.13 | 64.27 | 1.71 | 2017 | 09 |
737451 | 49.75 | 23.2 | 1164.12 | 491.548 | 384.976 | 9.61686 | 1.65324 | 302.55 | 301.959 | 298.045 | 299.372 | 298.819 | 351.453 | 308.115 | 405.107 | 466.832 | 876.591 | 407.299 | 502.301 | 340.844 | 433.966 | 64.27 | 1.71 | 2017 | 09 |
737452 | 49.75 | 23.2 | 1164.12 | 468.019 | 384.801 | 9.61497 | 1.6531 | 300.355 | 292.865 | 298.625 | 298.717 | 297.395 | 362.464 | 308.115 | 413.754 | 514.143 | 881.323 | 378.969 | 500.1 | 374.354 | 441.182 | 64.27 | 1.71 | 2017 | 09 |
737453 rows × 25 columns
# List the distinct values in 'Year' to see whether the column varies at all.
data['Year'].unique()
array(['2017'], dtype=object)
# 'Year' holds a single value ('2017'), so it carries no information: drop it.
data = data.drop(columns=['Year'])
データの分割
# Column to predict.
target = '% Silica Concentrate'
y = data[target]

# Two candidate feature sets:
#   X_i keeps '% Iron Concentrate' among the inputs,
#   X_n leaves it out.
X_i = data.drop(columns=[target])
X_n = X_i.drop(columns=['% Iron Concentrate'])
スケーリング
# Split BEFORE scaling so the scaler's mean/variance come from the training
# rows only; fitting on the full data would leak test-set statistics into
# training. Both splits share one seed, so the two feature sets are divided
# into the same train/test rows and their scores are directly comparable
# (and reproducible across runs).
SPLIT_SEED = 42
X_n_train, X_n_test, y_n_train, y_n_test = train_test_split(
    X_n, y, train_size=0.7, random_state=SPLIT_SEED)
X_i_train, X_i_test, y_i_train, y_i_test = train_test_split(
    X_i, y, train_size=0.7, random_state=SPLIT_SEED)

# One scaler per feature set: fit on the training part, then apply the same
# transformation to the held-out part.
scaler_n = sp.StandardScaler()
X_n_train = scaler_n.fit_transform(X_n_train)
X_n_test = scaler_n.transform(X_n_test)

scaler_i = sp.StandardScaler()
X_i_train = scaler_i.fit_transform(X_i_train)
X_i_test = scaler_i.transform(X_i_test)
トレーニング
# One linear regressor per feature set.
model_n, model_i = slm.LinearRegression(), slm.LinearRegression()

# Fit on the features that exclude '% Iron Concentrate', then report the
# score on the held-out rows.
model_n.fit(X_n_train, y_n_train)
r2_without_iron = model_n.score(X_n_test, y_n_test)
print('Model without iron R^2 Score:', r2_without_iron)
Model without iron R^2 Score: 0.15409635166200997
# Same procedure for the feature set that keeps '% Iron Concentrate'.
model_i.fit(X_i_train, y_i_train)
r2_with_iron = model_i.score(X_i_test, y_i_test)
print('Model with iron R^2 Score:', r2_with_iron)
Model with iron R^2 Score: 0.6875874592764709