tldr
KaggleのVideo Game Salesデータセットを使い、「Video Game Sales Prediction - Data Every Day #028」に沿って進めていきます。
実行環境はGoogle Colaboratoryです。
どんなデータ
ゲームの全世界での売上（`Global_Sales`）を、プラットフォーム・発売年・ジャンル・パブリッシャーから予測します。
インポート
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn.preprocessing as sp
from sklearn.model_selection import train_test_split
import tensorflow as tf
from tensorflow_addons.metrics import RSquare
データのダウンロード
Google Driveをマウントします。
# Mount Google Drive (available only inside Colab) so that files stored under
# "My Drive" can be read from /content/drive.
import google.colab.drive as drive

drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
KaggleのAPIクライアントを初期化し、認証します。
認証情報はGoogle Drive内の`/content/drive/My Drive/Colab Notebooks/Kaggle`に`kaggle.json`として置いてあります。
# Point the Kaggle client at the Google Drive directory that holds kaggle.json.
# KAGGLE_CONFIG_DIR is set *before* the kaggle import below.
# NOTE(review): the kaggle package appears to read its configuration at import
# time, so keep this ordering -- confirm before reordering these lines.
import os
kaggle_path = "/content/drive/My Drive/Colab Notebooks/Kaggle"
os.environ['KAGGLE_CONFIG_DIR'] = kaggle_path
from kaggle.api.kaggle_api_extended import KaggleApi
# Create an API client and authenticate with the credentials configured above.
api = KaggleApi()
api.authenticate()
Kaggle APIを使ってデータをダウンロードします。
# List the files published under the dataset and take the first one's name.
# file_path is where that file is expected to live once downloaded and unzipped.
dataset_id = 'gregorut/videogamesales'
dataset = api.dataset_list_files(dataset_id)
first_entry = dataset.files[0]
file_name = first_entry.name
file_path = os.path.join(api.get_default_download_dir(), file_name)
file_path
Warning: Looks like you're using an outdated API Version, please consider updating (server 1.5.10 / client 1.5.9)
'/content/vgsales.csv'
# Download the (zipped) file into the directory that file_path points at.
# Passing `path` explicitly keeps the download location in sync with file_path;
# with path=None the client chooses its own directory, which presumably differs
# (e.g. a datasets/<owner>/<slug> subfolder) when a download path is configured.
api.dataset_download_file(
    dataset_id,
    file_name,
    path=os.path.dirname(file_path),
    force=True,
    quiet=False,
)
100%|██████████| 381k/381k [00:00<00:00, 65.9MB/s]
Downloading vgsales.csv.zip to /content
True
import zipfile

# The download is saved as "<file_name>.zip". Derive both the archive path and
# the extraction directory from file_path rather than hard-coding /content, so
# the later read_csv(file_path) looks where the CSV was actually extracted.
# Assumes the archive was downloaded into the same directory as file_path.
zip_path = file_path + '.zip'
with zipfile.ZipFile(zip_path) as existing_zip:
    existing_zip.extractall(os.path.dirname(file_path))
データの読み込み
Pandasを使って、ダウンロードしたCSVファイルを読み込みます。
# Load the CSV and use its Rank column as the row index, then display it.
data = pd.read_csv(file_path).set_index('Rank')
data
Name | Platform | Year | Genre | Publisher | NA_Sales | EU_Sales | JP_Sales | Other_Sales | Global_Sales | |
---|---|---|---|---|---|---|---|---|---|---|
Rank | ||||||||||
1 | Wii Sports | Wii | 2006.0 | Sports | Nintendo | 41.49 | 29.02 | 3.77 | 8.46 | 82.74 |
2 | Super Mario Bros. | NES | 1985.0 | Platform | Nintendo | 29.08 | 3.58 | 6.81 | 0.77 | 40.24 |
3 | Mario Kart Wii | Wii | 2008.0 | Racing | Nintendo | 15.85 | 12.88 | 3.79 | 3.31 | 35.82 |
4 | Wii Sports Resort | Wii | 2009.0 | Sports | Nintendo | 15.75 | 11.01 | 3.28 | 2.96 | 33.00 |
5 | Pokemon Red/Pokemon Blue | GB | 1996.0 | Role-Playing | Nintendo | 11.27 | 8.89 | 10.22 | 1.00 | 31.37 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
16596 | Woody Woodpecker in Crazy Castle 5 | GBA | 2002.0 | Platform | Kemco | 0.01 | 0.00 | 0.00 | 0.00 | 0.01 |
16597 | Men in Black II: Alien Escape | GC | 2003.0 | Shooter | Infogrames | 0.01 | 0.00 | 0.00 | 0.00 | 0.01 |
16598 | SCORE International Baja 1000: The Official Game | PS2 | 2008.0 | Racing | Activision | 0.00 | 0.00 | 0.00 | 0.00 | 0.01 |
16599 | Know How 2 | DS | 2010.0 | Puzzle | 7G//AMES | 0.00 | 0.01 | 0.00 | 0.00 | 0.01 |
16600 | Spirits & Spells | GBA | 2003.0 | Platform | Wanadoo | 0.01 | 0.00 | 0.00 | 0.00 | 0.01 |
16598 rows × 10 columns
不要な列の削除
# Drop the free-text Name column and the four regional sales columns.
# The regional figures add up to Global_Sales (the prediction target) in the
# rows shown above, so keeping them would presumably leak the target.
columns_to_drop = ['Name'] + [f'{region}_Sales' for region in ('NA', 'EU', 'JP', 'Other')]
data = data.drop(columns=columns_to_drop)
# Count the missing values remaining in each column.
data.isnull().sum()
Platform 0
Year 271
Genre 0
Publisher 58
Global_Sales 0
dtype: int64
欠損値の処理
# Impute missing release years with the column mean, then drop every row that
# still has a NaN (per the counts above, only Publisher has gaps left).
data = data.fillna({'Year': data['Year'].mean()})
data = data.dropna()
data.isna().sum()
Platform 0
Year 0
Genre 0
Publisher 0
Global_Sales 0
dtype: int64
エンコード
Onehot Features
# Collapse publishers that appear in fewer than 50 rows into a single
# 'Small Publisher' category to limit the number of one-hot columns.
counts = data['Publisher'].value_counts()
is_small = data['Publisher'].map(counts) < 50
data['Publisher'] = data['Publisher'].mask(is_small, 'Small Publisher')
def onehot_encode(data, columns, *, prefix=False):
    """Replace each categorical column in *columns* with one-hot dummy columns.

    For every listed column, its dummy columns are appended to the right of the
    frame and the source column is dropped. The input frame is not modified.

    Args:
        data: DataFrame containing the columns to encode.
        columns: Names of the categorical columns, encoded in the given order.
        prefix: If True, name dummies ``"<column>_<value>"`` instead of the
            bare ``"<value>"``. The default keeps the original bare names.

    Returns:
        A new DataFrame with the listed columns replaced by their dummies.

    Raises:
        ValueError: If a dummy column name would clash with a column already
            in the frame. Otherwise pd.concat would create duplicate labels and
            the following drop could silently remove the wrong columns.
    """
    for column in columns:
        dummies = pd.get_dummies(data[column], prefix=column if prefix else None)
        clashes = dummies.columns.intersection(data.columns)
        if len(clashes) > 0:
            raise ValueError(
                f'one-hot columns for {column!r} clash with existing columns: '
                f'{list(clashes)}; pass prefix=True to disambiguate'
            )
        data = pd.concat([data, dummies], axis=1)
        data = data.drop(column, axis=1)
    return data
# One-hot encode the three categorical feature columns, then preview the result.
onehot_columns = ['Platform', 'Genre', 'Publisher']
data = data.pipe(onehot_encode, onehot_columns)
data.head(5)
Year | Global_Sales | 2600 | 3DO | 3DS | DC | DS | GB | GBA | GC | GEN | GG | N64 | NES | NG | PC | PCFX | PS | PS2 | PS3 | PS4 | PSP | PSV | SAT | SCD | SNES | TG16 | WS | Wii | WiiU | X360 | XB | XOne | Action | Adventure | Fighting | Misc | Platform | Puzzle | Racing | ... | Bethesda Softworks | Capcom | Codemasters | Crave Entertainment | D3Publisher | Deep Silver | Disney Interactive Studios | Eidos Interactive | Electronic Arts | Empire Interactive | Focus Home Interactive | Hudson Soft | Idea Factory | Ignition Entertainment | Infogrames | Kadokawa Shoten | Konami Digital Entertainment | LucasArts | Majesco Entertainment | Marvelous Interactive | Microsoft Game Studios | Midway Games | Namco Bandai Games | Nintendo | Nippon Ichi Software | Rising Star Games | Sega | Small Publisher | Sony Computer Entertainment | Square Enix | SquareSoft | THQ | Take-Two Interactive | Tecmo Koei | Ubisoft | Unknown | Virgin Interactive | Vivendi Games | Warner Bros. Interactive Entertainment | Zoo Digital Publishing | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
Rank | |||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
1 | 2006.0 | 82.74 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
2 | 1985.0 | 40.24 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
3 | 2008.0 | 35.82 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
4 | 2009.0 | 33.00 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
5 | 1996.0 | 31.37 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | ... | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 1 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 | 0 |
5 rows × 92 columns
スケーリング
# Split off the target (Global_Sales) and standardise every remaining feature
# column, including the one-hot columns, with StandardScaler.
# NOTE(review): the scaler is fit on all rows before the train/test split that
# follows, so test-set statistics influence the scaling (mild data leakage).
y = data['Global_Sales']
X = data.drop(columns='Global_Sales')
scaler = sp.StandardScaler().fit(X)
X = scaler.transform(X)
pd.DataFrame(X).head()
0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | ... | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | -0.070070 | -0.090035 | -0.013469 | -0.177827 | -0.056159 | -0.387155 | -0.077203 | -0.224853 | -0.186507 | -0.040436 | -0.007776 | -0.140235 | -0.077203 | -0.026945 | -0.247404 | -0.007776 | -0.27881 | -0.387464 | -0.295223 | -0.143999 | -0.28107 | -0.159631 | -0.102811 | -0.01905 | -0.121085 | -0.010997 | -0.01905 | 3.390051 | -0.093387 | -0.287283 | -0.228977 | -0.114219 | -0.500094 | -0.289865 | -0.232177 | -0.33979 | -0.237621 | -0.190803 | -0.285677 | -0.314184 | ... | -0.065659 | -0.153552 | -0.096307 | -0.065659 | -0.106065 | -0.086202 | -0.115569 | -0.110073 | -0.298238 | -0.056159 | -0.059321 | -0.070152 | -0.08866 | -0.060841 | -0.06134 | -0.055065 | -0.230145 | -0.073967 | -0.074789 | -0.058286 | -0.107512 | -0.110073 | -0.244362 | 4.746339 | -0.07993 | -0.072296 | -0.200465 | -0.50321 | -0.207539 | -0.119534 | -0.056159 | -0.21256 | -0.160029 | -0.144436 | -0.24283 | -0.111471 | -0.06134 | -0.100073 | -0.119273 | -0.079546 |
1 | -3.698162 | -0.090035 | -0.013469 | -0.177827 | -0.056159 | -0.387155 | -0.077203 | -0.224853 | -0.186507 | -0.040436 | -0.007776 | -0.140235 | 12.952819 | -0.026945 | -0.247404 | -0.007776 | -0.27881 | -0.387464 | -0.295223 | -0.143999 | -0.28107 | -0.159631 | -0.102811 | -0.01905 | -0.121085 | -0.010997 | -0.01905 | -0.294981 | -0.093387 | -0.287283 | -0.228977 | -0.114219 | -0.500094 | -0.289865 | -0.232177 | -0.33979 | 4.208373 | -0.190803 | -0.285677 | -0.314184 | ... | -0.065659 | -0.153552 | -0.096307 | -0.065659 | -0.106065 | -0.086202 | -0.115569 | -0.110073 | -0.298238 | -0.056159 | -0.059321 | -0.070152 | -0.08866 | -0.060841 | -0.06134 | -0.055065 | -0.230145 | -0.073967 | -0.074789 | -0.058286 | -0.107512 | -0.110073 | -0.244362 | 4.746339 | -0.07993 | -0.072296 | -0.200465 | -0.50321 | -0.207539 | -0.119534 | -0.056159 | -0.21256 | -0.160029 | -0.144436 | -0.24283 | -0.111471 | -0.06134 | -0.100073 | -0.119273 | -0.079546 |
2 | 0.275463 | -0.090035 | -0.013469 | -0.177827 | -0.056159 | -0.387155 | -0.077203 | -0.224853 | -0.186507 | -0.040436 | -0.007776 | -0.140235 | -0.077203 | -0.026945 | -0.247404 | -0.007776 | -0.27881 | -0.387464 | -0.295223 | -0.143999 | -0.28107 | -0.159631 | -0.102811 | -0.01905 | -0.121085 | -0.010997 | -0.01905 | 3.390051 | -0.093387 | -0.287283 | -0.228977 | -0.114219 | -0.500094 | -0.289865 | -0.232177 | -0.33979 | -0.237621 | -0.190803 | 3.500458 | -0.314184 | ... | -0.065659 | -0.153552 | -0.096307 | -0.065659 | -0.106065 | -0.086202 | -0.115569 | -0.110073 | -0.298238 | -0.056159 | -0.059321 | -0.070152 | -0.08866 | -0.060841 | -0.06134 | -0.055065 | -0.230145 | -0.073967 | -0.074789 | -0.058286 | -0.107512 | -0.110073 | -0.244362 | 4.746339 | -0.07993 | -0.072296 | -0.200465 | -0.50321 | -0.207539 | -0.119534 | -0.056159 | -0.21256 | -0.160029 | -0.144436 | -0.24283 | -0.111471 | -0.06134 | -0.100073 | -0.119273 | -0.079546 |
3 | 0.448229 | -0.090035 | -0.013469 | -0.177827 | -0.056159 | -0.387155 | -0.077203 | -0.224853 | -0.186507 | -0.040436 | -0.007776 | -0.140235 | -0.077203 | -0.026945 | -0.247404 | -0.007776 | -0.27881 | -0.387464 | -0.295223 | -0.143999 | -0.28107 | -0.159631 | -0.102811 | -0.01905 | -0.121085 | -0.010997 | -0.01905 | 3.390051 | -0.093387 | -0.287283 | -0.228977 | -0.114219 | -0.500094 | -0.289865 | -0.232177 | -0.33979 | -0.237621 | -0.190803 | -0.285677 | -0.314184 | ... | -0.065659 | -0.153552 | -0.096307 | -0.065659 | -0.106065 | -0.086202 | -0.115569 | -0.110073 | -0.298238 | -0.056159 | -0.059321 | -0.070152 | -0.08866 | -0.060841 | -0.06134 | -0.055065 | -0.230145 | -0.073967 | -0.074789 | -0.058286 | -0.107512 | -0.110073 | -0.244362 | 4.746339 | -0.07993 | -0.072296 | -0.200465 | -0.50321 | -0.207539 | -0.119534 | -0.056159 | -0.21256 | -0.160029 | -0.144436 | -0.24283 | -0.111471 | -0.06134 | -0.100073 | -0.119273 | -0.079546 |
4 | -1.797732 | -0.090035 | -0.013469 | -0.177827 | -0.056159 | -0.387155 | 12.952819 | -0.224853 | -0.186507 | -0.040436 | -0.007776 | -0.140235 | -0.077203 | -0.026945 | -0.247404 | -0.007776 | -0.27881 | -0.387464 | -0.295223 | -0.143999 | -0.28107 | -0.159631 | -0.102811 | -0.01905 | -0.121085 | -0.010997 | -0.01905 | -0.294981 | -0.093387 | -0.287283 | -0.228977 | -0.114219 | -0.500094 | -0.289865 | -0.232177 | -0.33979 | -0.237621 | -0.190803 | -0.285677 | 3.182853 | ... | -0.065659 | -0.153552 | -0.096307 | -0.065659 | -0.106065 | -0.086202 | -0.115569 | -0.110073 | -0.298238 | -0.056159 | -0.059321 | -0.070152 | -0.08866 | -0.060841 | -0.06134 | -0.055065 | -0.230145 | -0.073967 | -0.074789 | -0.058286 | -0.107512 | -0.110073 | -0.244362 | 4.746339 | -0.07993 | -0.072296 | -0.200465 | -0.50321 | -0.207539 | -0.119534 | -0.056159 | -0.21256 | -0.160029 | -0.144436 | -0.24283 | -0.111471 | -0.06134 | -0.100073 | -0.119273 | -0.079546 |
5 rows × 91 columns
トレーニング
# Display the (rows, columns) dimensions of the scaled feature matrix.
np.shape(X)
(16540, 91)
# Hold out 20% of the rows for the final evaluation. Fixing random_state makes
# the split, and therefore the reported score, reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, random_state=1)

# Two hidden ReLU layers and one linear output unit for regression.
# The input width comes from the data instead of a hard-coded 91, so the model
# still builds if the encoding yields a different number of columns
# (e.g. after changing the small-publisher threshold).
n_features = X_train.shape[1]
model = tf.keras.Sequential([
    tf.keras.layers.Dense(128, activation='relu', input_shape=(n_features,)),
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(1),
])
model.summary()
Model: "sequential_3"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_9 (Dense) (None, 128) 11776
_________________________________________________________________
dense_10 (Dense) (None, 128) 16512
_________________________________________________________________
dense_11 (Dense) (None, 1) 129
=================================================================
Total params: 28,417
Trainable params: 28,417
Non-trainable params: 0
_________________________________________________________________
# Mean-squared-error objective optimised with RMSprop (learning rate 1e-3).
model.compile(optimizer=tf.keras.optimizers.RMSprop(learning_rate=0.001), loss='mse')

batch_size, epochs = 64, 100
# Train without progress output, reserving 20% of the training rows for
# validation (validation_split). Per-epoch losses are kept in `history`.
fit_options = {
    'validation_split': 0.2,
    'batch_size': batch_size,
    'epochs': epochs,
    'verbose': 0,
}
history = model.fit(X_train, y_train, **fit_options)
結果
# Plot training and validation loss per epoch.
plt.figure(figsize=(14, 10))
train_loss = history.history['loss']
val_loss = history.history['val_loss']
# Take the x-axis from the recorded history rather than the requested `epochs`,
# so the plot still lines up if training stops early.
epochs_range = range(1, len(train_loss) + 1)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')  # y-axis label; a second xlabel call would overwrite 'Epoch'
plt.legend()  # required for the `label=` strings above to be displayed
plt.show()
# Predict on the held-out rows. model.predict returns an (n, 1) array; ravel()
# flattens it to (n,) and, unlike squeeze, keeps a 1-D shape even when n == 1.
y_pred = model.predict(X_test).ravel()
result = RSquare()
result.update_state(y_test, y_pred)
# result() yields a scalar tf.Tensor; convert it so a plain number is printed
# rather than the Tensor repr.
r2 = float(result.result())
print(f'R^2 Score: {r2:.5f}')
R^2 Score: tf.Tensor(0.0019521117, shape=(), dtype=float32)