Introduction
Using TabNet, a method published in 2020 that has been attracting attention, I tackled house price prediction, the introductory regression problem on Kaggle.
What is TabNet?
Paper - https://arxiv.org/pdf/1908.07442.pdf
GitHub - https://github.com/dreamquark-ai/tabnet
Features of TabNet
TabNet has the following characteristics:
- On tabular data, it matches or exceeds the accuracy of other learning models on both classification and regression problems.
- It is a deep learning model, yet it remains interpretable.
- It can produce high-performing models without manual feature engineering.
What it does
TabNet works roughly as follows:
- It masks parts of the given data and performs unsupervised pre-training.
- It then performs transfer learning from the pre-trained weights to make predictions.
Using the pre-training results makes the feature importances interpretable.
- Raw numeric features are first passed through batch normalization.
- Then come n_steps feature-selection steps: step i takes the information processed at step i-1 as input, performs feature selection, and aggregates the information.
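As a loose illustration of these sequential feature-selection steps (a toy NumPy sketch, not the actual TabNet implementation, which learns its masks with sparsemax/entmax attention), each step masks out most features, processes only the selected ones, and discourages later steps from reusing them:

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 6))          # batch of 4 samples, 6 features

n_steps = 3
prior = np.ones(6)                   # damps features already used by earlier steps
outputs = []
for step in range(n_steps):
    scores = rng.normal(size=6) * prior            # stand-in for the learned attentive transformer
    mask = (scores == scores.max()).astype(float)  # hard one-feature mask (sparsemax analogue)
    outputs.append(X * mask)                       # each step only sees its selected features
    prior = prior * (1.0 - 0.9 * mask)             # discourage re-selecting the same feature
agg = np.sum(outputs, axis=0)                      # aggregate the step outputs for the final decision
print(agg.shape)  # (4, 6)
```

In the real model the masks are soft, learned, and per-sample, which is what makes per-feature importance scores recoverable afterwards.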
Implementation
Now let's apply TabNet to the house price prediction problem.
Reference code for the preprocessing - https://www.kaggle.com/serigne/stacked-regressions-top-4-on-leaderboard
Importing modules
First, import the modules used in this article.
import csv
import pandas as pd
import numpy as np
import seaborn as sns
from scipy.stats import norm
from scipy import stats
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import mean_squared_error, mean_absolute_error
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from pytorch_tabnet.pretraining import TabNetPretrainer
from pytorch_tabnet.tab_model import TabNetRegressor
import torch
If you don't have pytorch_tabnet installed yet, install it first with:
$ pip install pytorch-tabnet
Preprocessing
Next, we download the data and preprocess it.
train_df = pd.read_csv('./data/train.csv')
test_df = pd.read_csv('./data/test.csv')
Checking the contents:
train_df.head()
(Output: the first five rows of train_df, 81 columns from `Id` through `SalePrice`, with `NaN` entries in columns such as `Alley` and `PoolQC`.)
We can see that numeric data and string data are mixed together.
Next, to make the subsequent processing easier, we set the `Id` column as the index.
train_df.set_index(keys='Id', inplace=True)
test_df.set_index(keys='Id', inplace=True)
train_df.head()
(Output: the same five rows, now indexed by `Id`, with 79 feature columns plus `SalePrice`.)
This is a regression problem, and accuracy tends to improve when the target variable `SalePrice` is close to normally distributed, so we first check its distribution.
sns.distplot(train_df['SalePrice'], fit=norm)
mu, sigma = norm.fit(train_df['SalePrice'])
print('mean {:.2f} , sigma = {:.2f}'.format(mu, sigma))
plt.ylabel('Frequency')
fig = plt.figure()
res = stats.probplot(train_df['SalePrice'], plot=plt)
plt.show()
mean 180921.20 , sigma = 79415.29
The distribution is somewhat skewed. House prices are close to log-normally distributed, so we take the logarithm and check the distribution again.
train_df['SalePrice'] = np.log1p(train_df['SalePrice'])
sns.distplot(train_df['SalePrice'], fit=norm)
mu, sigma = norm.fit(train_df['SalePrice'])
print('mean {:.2f} , sigma = {:.2f}'.format(mu, sigma))
plt.ylabel('Frequency')
fig = plt.figure()
res = stats.probplot(train_df['SalePrice'], plot=plt)
plt.show()
mean 12.02 , sigma = 0.40
The distribution is now much closer to normal.
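Note that `np.log1p` is exactly inverted by `np.expm1`, which is what lets us map predictions back to prices at submission time. A quick sanity check with a few of the prices seen above:

```python
import numpy as np

prices = np.array([208500.0, 181500.0, 140000.0])
logged = np.log1p(prices)        # log(1 + x): compresses the right-skewed tail
restored = np.expm1(logged)      # exact inverse of log1p
print(np.allclose(prices, restored))  # True
```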
We separate the data into the target variable and the explanatory variables.
train_df_y = train_df['SalePrice']
train_df.drop(['SalePrice'], axis=1, inplace=True)
We will now process the explanatory variables. So that the training and test data can be processed in one pass, we concatenate them.
ntrain = train_df.shape[0]
all_data = pd.concat((train_df, test_df)).reset_index(drop=True)
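Encoding the concatenated frame matters: if `pd.get_dummies` were applied to the training and test sets separately, a category present in only one of them would produce mismatched columns. A toy illustration with made-up values:

```python
import pandas as pd

train = pd.DataFrame({'RoofStyle': ['Gable', 'Hip']})
test = pd.DataFrame({'RoofStyle': ['Gable', 'Shed']})   # 'Shed' never appears in train

# Encoding separately produces different column sets
cols_sep_train = set(pd.get_dummies(train).columns)
cols_sep_test = set(pd.get_dummies(test).columns)
print(cols_sep_train == cols_sep_test)   # False

# Encoding the concatenated frame keeps the columns aligned
both = pd.get_dummies(pd.concat([train, test]).reset_index(drop=True))
print(both.shape[1])                     # one column per category seen anywhere: 3
```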
Let's check the missing values in the data.
all_data_na = (all_data.isnull().sum() / len(all_data)) * 100
all_data_na = all_data_na.drop(all_data_na[all_data_na == 0].index).sort_values(ascending=False)[:30]
missing_data = pd.DataFrame({'Missing Ratio' :all_data_na})
missing_data.head(22)
| | Missing Ratio |
|---|---|
| PoolQC | 99.725557 |
| MiscFeature | 96.397942 |
| Alley | 93.207547 |
| Fence | 80.445969 |
| FireplaceQu | 48.713551 |
| LotFrontage | 16.672384 |
| GarageFinish | 5.454545 |
| GarageYrBlt | 5.454545 |
| GarageQual | 5.454545 |
| GarageCond | 5.454545 |
| GarageType | 5.385935 |
| BsmtExposure | 2.813036 |
| BsmtCond | 2.813036 |
| BsmtQual | 2.778731 |
| BsmtFinType2 | 2.744425 |
| BsmtFinType1 | 2.710120 |
| MasVnrType | 0.823328 |
| MasVnrArea | 0.789022 |
| MSZoning | 0.137221 |
| BsmtFullBath | 0.068611 |
| BsmtHalfBath | 0.068611 |
| Utilities | 0.068611 |
There are quite a lot of missing values.
We now fill them in.
def na_change(all_data):
    all_data["PoolQC"] = all_data["PoolQC"].fillna("None")
    all_data["MiscFeature"] = all_data["MiscFeature"].fillna("None")
    all_data["Alley"] = all_data["Alley"].fillna("None")
    all_data["Fence"] = all_data["Fence"].fillna("None")
    all_data["FireplaceQu"] = all_data["FireplaceQu"].fillna("None")
    # Group by neighborhood and fill missing values with the median LotFrontage of the neighborhood
    all_data["LotFrontage"] = all_data.groupby("Neighborhood")["LotFrontage"].transform(
        lambda x: x.fillna(x.median()))
    for col in ('GarageType', 'GarageFinish', 'GarageQual', 'GarageCond'):
        all_data[col] = all_data[col].fillna('None')
    for col in ('GarageYrBlt', 'GarageArea', 'GarageCars'):
        all_data[col] = all_data[col].fillna(0)
    for col in ('BsmtFinSF1', 'BsmtFinSF2', 'BsmtUnfSF', 'TotalBsmtSF', 'BsmtFullBath', 'BsmtHalfBath'):
        all_data[col] = all_data[col].fillna(0)
    for col in ('BsmtQual', 'BsmtCond', 'BsmtExposure', 'BsmtFinType1', 'BsmtFinType2'):
        all_data[col] = all_data[col].fillna('None')
    all_data["MasVnrType"] = all_data["MasVnrType"].fillna("None")
    all_data["MasVnrArea"] = all_data["MasVnrArea"].fillna(0)
    all_data['MSZoning'] = all_data['MSZoning'].fillna(all_data['MSZoning'].mode()[0])
    all_data = all_data.drop(['Utilities'], axis=1)
    all_data["Functional"] = all_data["Functional"].fillna("Typ")
    all_data['Electrical'] = all_data['Electrical'].fillna(all_data['Electrical'].mode()[0])
    all_data['KitchenQual'] = all_data['KitchenQual'].fillna(all_data['KitchenQual'].mode()[0])
    all_data['Exterior1st'] = all_data['Exterior1st'].fillna(all_data['Exterior1st'].mode()[0])
    all_data['Exterior2nd'] = all_data['Exterior2nd'].fillna(all_data['Exterior2nd'].mode()[0])
    all_data['SaleType'] = all_data['SaleType'].fillna(all_data['SaleType'].mode()[0])
    all_data['MSSubClass'] = all_data['MSSubClass'].fillna("None")
    return all_data

all_data = na_change(all_data)
# Check the missing values again
all_data_na = (all_data.isnull().sum() / len(all_data)) * 100
all_data_na = all_data_na.drop(all_data_na[all_data_na == 0].index).sort_values(ascending=False)[:30]
missing_data = pd.DataFrame({'Missing Ratio' :all_data_na})
missing_data.head(22)
(Output: an empty table.)
The missing values are all gone.
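The `LotFrontage` line in `na_change` uses a groupby-transform so that each missing value is filled with the median of its own `Neighborhood` rather than a global median. The same pattern in isolation, with made-up toy values:

```python
import numpy as np
import pandas as pd

df = pd.DataFrame({
    'Neighborhood': ['CollgCr', 'CollgCr', 'CollgCr', 'Veenker', 'Veenker'],
    'LotFrontage':  [60.0, 80.0, np.nan, 100.0, np.nan],
})
# Each NaN is filled with the median of its own neighborhood group
df['LotFrontage'] = df.groupby('Neighborhood')['LotFrontage'].transform(
    lambda x: x.fillna(x.median()))
print(df['LotFrontage'].tolist())  # [60.0, 80.0, 70.0, 100.0, 100.0]
```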
Next, we convert the string columns into numeric ones with one-hot encoding, using pandas' `get_dummies`.
all_data = pd.get_dummies(all_data)
all_data.head()
(Output: the first five rows of all_data after encoding; every column is numeric, ending in one-hot columns such as `SaleType_WD` and `SaleCondition_Normal`.)
All string columns are gone; everything is now numeric.
Preprocessing is complete, so we split the data back into training and test sets.
train_df = all_data[:ntrain]
test_df = all_data[ntrain:]
print(train_df.shape, test_df.shape)
Training
To train and evaluate the model, we split the training data.
X = train_df.values
y = train_df_y.values
print(X.shape, y.shape)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.10, random_state=1)
print(X_train.shape, X_test.shape)
Next, we build the model. The flow of standardizing the data and then fitting the model is bundled into a single pipeline: a class that trains when `fit` is called and predicts when `predict` is called.
class Pipeline():
    def __init__(self, pretrainer, model):
        self.stdsc = StandardScaler()
        self.pretrainer = pretrainer
        self.model = model

    def fit(self, X, y, X_valid, y_valid):
        # Standardize (fit the scaler on the training data only)
        self.X_std = self.stdsc.fit_transform(X)
        self.X_valid_std = self.stdsc.transform(X_valid)
        # Unsupervised pre-training
        self.pretrainer.fit(X_train=self.X_std,
                            eval_set=[self.X_valid_std],
                            max_epochs=1000,
                            patience=100)
        # Supervised training, warm-started from the pre-trained weights
        self.model.fit(X_train=self.X_std,
                       y_train=y,
                       eval_set=[(self.X_valid_std, y_valid)],
                       eval_name=["valid"],
                       eval_metric=['rmse'],
                       max_epochs=5000,
                       from_unsupervised=self.pretrainer)
        return self

    def predict(self, X):
        self.X_std = self.stdsc.transform(X)
        prediction = self.model.predict(self.X_std)
        return np.array(prediction)
Next, we define TabNet and train it. The hyperparameters are defined in `tabnet_params`, and the models are created with `pretrainer = TabNetPretrainer(**tabnet_params)` and `model = TabNetRegressor(**tabnet_params)`. For the meaning of each parameter, refer to the pytorch-tabnet documentation (GitHub link above). Also, TabNet's accuracy varied considerably between runs, so we build several models and keep the best one as the final model.
tabnet_params = dict(n_d=8, n_a=8, n_steps=5, gamma=1,
n_independent=2, n_shared=2,
seed=1998, lambda_sparse=1e-3,
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
mask_type="entmax",
scheduler_params=dict(mode="min",
patience=5,
min_lr=1e-5,
factor=0.9,),
scheduler_fn=torch.optim.lr_scheduler.ReduceLROnPlateau,
verbose=50
)
rmse_best = np.inf

# Split off a validation set once; doing this inside the loop
# would shrink the training data on every iteration
X_train, X_valid, y_train, y_valid = train_test_split(X_train, y_train, test_size=0.10, random_state=1)
y_train = y_train.reshape(-1, 1)
y_valid = y_valid.reshape(-1, 1)

for i in range(10):
    print('######### Run {} ##############'.format(i + 1))
    # Define the models
    pretrainer = TabNetPretrainer(**tabnet_params)
    model = TabNetRegressor(**tabnet_params)
    # Build the pipeline
    pipeline = Pipeline(pretrainer, model)
    # Train
    pipeline.fit(X_train, y_train, X_valid, y_valid)
    # Predict on the held-out test data
    y_pred = pipeline.predict(X_test)
    # Compute the RMSE and keep the model if it is the best so far
    rmse = np.sqrt(mean_squared_error(y_test, y_pred))
    print('Run {} RMSE: {}'.format(i + 1, rmse))
    if rmse < rmse_best:
        rmse_best = rmse
        model_best = pipeline
######### Run 1 ##############
Device used : cuda
Device used : cuda
epoch 0 | loss: 112869120.75997| val_0_unsup_loss: 8152667.0| 0:00:00s
epoch 50 | loss: 84455.6602| val_0_unsup_loss: 842184.0625| 0:00:12s
epoch 100| loss: 51505.95608| val_0_unsup_loss: 301123.28125| 0:00:25s
epoch 150| loss: 8935.98715| val_0_unsup_loss: 25950.01758| 0:00:37s
epoch 200| loss: 491.73772| val_0_unsup_loss: 1192.3595| 0:00:49s
epoch 250| loss: 74.34149| val_0_unsup_loss: 336.15506| 0:01:01s
epoch 300| loss: 31.98659| val_0_unsup_loss: 139.42387| 0:01:13s
epoch 350| loss: 1666.08146| val_0_unsup_loss: 152.79266| 0:01:25s
epoch 400| loss: 3855.69143| val_0_unsup_loss: 38.82077| 0:01:37s
epoch 450| loss: 21.85563| val_0_unsup_loss: 19.67164| 0:01:49s
epoch 500| loss: 8.56555 | val_0_unsup_loss: 18.34324| 0:02:01s
epoch 550| loss: 5.18283 | val_0_unsup_loss: 14.3733 | 0:02:13s
epoch 600| loss: 6.82209 | val_0_unsup_loss: 15.61955| 0:02:25s
Early stopping occurred at epoch 641 with best_epoch = 541 and best_val_0_unsup_loss = 12.81931
Best weights from best epoch are automatically used!
Loading weights from unsupervised pretraining
epoch 0 | loss: 147.09146| valid_rmse: 12.06636| 0:00:00s
epoch 50 | loss: 1.10993 | valid_rmse: 1.30975 | 0:00:10s
Early stopping occurred at epoch 74 with best_epoch = 64 and best_valid_rmse = 0.56005
Best weights from best epoch are automatically used!
Run 1 RMSE: 0.5732574974545894
######### Run 2 ##############
Device used : cuda
Device used : cuda
epoch 0 | loss: 56876420.88596| val_0_unsup_loss: 18685788.0| 0:00:00s
epoch 50 | loss: 148481.84402| val_0_unsup_loss: 892821.3125| 0:00:12s
epoch 100| loss: 5002.76485| val_0_unsup_loss: 58444.03516| 0:00:24s
epoch 150| loss: 2183.35928| val_0_unsup_loss: 9718.08008| 0:00:35s
epoch 200| loss: 224.419 | val_0_unsup_loss: 1830.16101| 0:00:47s
epoch 250| loss: 14187.41747| val_0_unsup_loss: 701.71118| 0:01:01s
epoch 300| loss: 6390.98186| val_0_unsup_loss: 150.8916| 0:01:14s
epoch 350| loss: 12570.42858| val_0_unsup_loss: 81.2753 | 0:01:26s
epoch 400| loss: 25.52996| val_0_unsup_loss: 106.70511| 0:01:37s
epoch 450| loss: 7195.31966| val_0_unsup_loss: 77.783 | 0:01:48s
epoch 500| loss: 17.96937| val_0_unsup_loss: 54.49159| 0:02:00s
epoch 550| loss: 13.97167| val_0_unsup_loss: 75.29209| 0:02:11s
epoch 600| loss: 20.92112| val_0_unsup_loss: 53.34031| 0:02:22s
Early stopping occurred at epoch 609 with best_epoch = 509 and best_val_0_unsup_loss = 23.04633
Best weights from best epoch are automatically used!
Loading weights from unsupervised pretraining
epoch 0 | loss: 138.89917| valid_rmse: 11.74293| 0:00:00s
epoch 50 | loss: 0.58857 | valid_rmse: 0.92686 | 0:00:09s
epoch 100| loss: 0.11881 | valid_rmse: 0.35115 | 0:00:18s
Early stopping occurred at epoch 118 with best_epoch = 108 and best_valid_rmse = 0.25378
Best weights from best epoch are automatically used!
Run 2 RMSE: 0.21336961298887985
######### Run 3 ##############
Device used : cuda
Device used : cuda
epoch 0 | loss: 53047928.0| val_0_unsup_loss: 16785738.0| 0:00:00s
epoch 50 | loss: 127546.4375| val_0_unsup_loss: 128776056.0| 0:00:08s
epoch 100| loss: 7766.67334| val_0_unsup_loss: 77044832.0| 0:00:16s
Early stopping occurred at epoch 107 with best_epoch = 7 and best_val_0_unsup_loss = 8168445.0
Best weights from best epoch are automatically used!
Loading weights from unsupervised pretraining
epoch 0 | loss: 142.56097| valid_rmse: 11.98251| 0:00:00s
epoch 50 | loss: 0.46553 | valid_rmse: 1.26716 | 0:00:07s
epoch 100| loss: 0.05202 | valid_rmse: 0.2648 | 0:00:13s
Early stopping occurred at epoch 121 with best_epoch = 111 and best_valid_rmse = 0.24816
Best weights from best epoch are automatically used!
Run 3 RMSE: 0.6404878648056896
######### Run 4 ##############
Device used : cuda
Device used : cuda
epoch 0 | loss: 36778532.0| val_0_unsup_loss: 11881544.0| 0:00:00s
epoch 50 | loss: 301578.96875| val_0_unsup_loss: 361585120.0| 0:00:07s
epoch 100| loss: 9690.23145| val_0_unsup_loss: 12532827.0| 0:00:15s
Early stopping occurred at epoch 109 with best_epoch = 9 and best_val_0_unsup_loss = 4209855.5
Best weights from best epoch are automatically used!
Loading weights from unsupervised pretraining
epoch 0 | loss: 119.93172| valid_rmse: 11.93541| 0:00:00s
epoch 50 | loss: 0.26538 | valid_rmse: 0.81614 | 0:00:06s
Early stopping occurred at epoch 89 with best_epoch = 79 and best_valid_rmse = 0.26959
Best weights from best epoch are automatically used!
Run 4 RMSE: 0.3138825021882866
######### Run 5 ##############
Device used : cuda
Device used : cuda
epoch 0 | loss: 98724592.0| val_0_unsup_loss: 18642930.0| 0:00:00s
epoch 50 | loss: 207865.03125| val_0_unsup_loss: 10883383.0| 0:00:07s
epoch 100| loss: 2071.20288| val_0_unsup_loss: 4162254.5| 0:00:15s
epoch 150| loss: 178.09682| val_0_unsup_loss: 1273017.375| 0:00:24s
epoch 200| loss: 12.41261| val_0_unsup_loss: 1266617.75| 0:00:31s
epoch 250| loss: 5.4191 | val_0_unsup_loss: 1168223.75| 0:00:39s
epoch 300| loss: 71.96937| val_0_unsup_loss: 1217691.375| 0:00:46s
Early stopping occurred at epoch 325 with best_epoch = 225 and best_val_0_unsup_loss = 1060690.375
Best weights from best epoch are automatically used!
Loading weights from unsupervised pretraining
epoch 0 | loss: 144.38873| valid_rmse: 12.00936| 0:00:00s
Early stopping occurred at epoch 31 with best_epoch = 21 and best_valid_rmse = 7.15481
Best weights from best epoch are automatically used!
Run 5 RMSE: 7.245829108461303
######### Run 6 ##############
Device used : cuda
Device used : cuda
epoch 0 | loss: 168969248.0| val_0_unsup_loss: 19196760.0| 0:00:00s
epoch 50 | loss: 226875.89062| val_0_unsup_loss: 3304720.75| 0:00:07s
epoch 100| loss: 9766.67383| val_0_unsup_loss: 2984239.75| 0:00:14s
epoch 150| loss: 1296.66162| val_0_unsup_loss: 3050195.5| 0:00:21s
Early stopping occurred at epoch 164 with best_epoch = 64 and best_val_0_unsup_loss = 2317795.75
Best weights from best epoch are automatically used!
Loading weights from unsupervised pretraining
epoch 0 | loss: 145.52234| valid_rmse: 11.98945| 0:00:00s
epoch 50 | loss: 0.59494 | valid_rmse: 1.2914 | 0:00:06s
epoch 100| loss: 0.06902 | valid_rmse: 0.33971 | 0:00:11s
Early stopping occurred at epoch 106 with best_epoch = 96 and best_valid_rmse = 0.31578
Best weights from best epoch are automatically used!
Run 6 RMSE: 0.3511308647629572
######### Run 7 ##############
Device used : cuda
Device used : cuda
epoch 0 | loss: 140554368.0| val_0_unsup_loss: 25966662.0| 0:00:00s
epoch 50 | loss: 57978.83594| val_0_unsup_loss: 1957874.25| 0:00:07s
epoch 100| loss: 40.81237| val_0_unsup_loss: 1858593.25| 0:00:14s
epoch 150| loss: 0.87732 | val_0_unsup_loss: 1858593.375| 0:00:20s
Early stopping occurred at epoch 165 with best_epoch = 65 and best_val_0_unsup_loss = 1858593.25
Best weights from best epoch are automatically used!
Loading weights from unsupervised pretraining
epoch 0 | loss: 144.86298| valid_rmse: 12.04391| 0:00:00s
Early stopping occurred at epoch 39 with best_epoch = 29 and best_valid_rmse = 6.68239
Best weights from best epoch are automatically used!
Run 7 RMSE: 7.947661647717846
######### Run 8 ##############
Device used : cuda
Device used : cuda
epoch 0 | loss: 181639712.0| val_0_unsup_loss: 20568860.0| 0:00:00s
epoch 50 | loss: 341412.625| val_0_unsup_loss: 4480100.0| 0:00:06s
epoch 100| loss: 67.57111| val_0_unsup_loss: 1612525.25| 0:00:14s
epoch 150| loss: 6.24322 | val_0_unsup_loss: 1612525.25| 0:00:21s
epoch 200| loss: 0.96622 | val_0_unsup_loss: 1612525.25| 0:00:27s
Early stopping occurred at epoch 200 with best_epoch = 100 and best_val_0_unsup_loss = 1612525.25
Best weights from best epoch are automatically used!
Loading weights from unsupervised pretraining
epoch 0 | loss: 144.55081| valid_rmse: 12.05962| 0:00:00s
Early stopping occurred at epoch 49 with best_epoch = 39 and best_valid_rmse = 4.02068
Best weights from best epoch are automatically used!
Run 8 RMSE: 4.088560582540583
######### Run 9 ##############
Device used : cuda
Device used : cuda
epoch 0 | loss: 244117152.0| val_0_unsup_loss: 63549964.0| 0:00:00s
epoch 50 | loss: 596912.375| val_0_unsup_loss: 5337391.0| 0:00:06s
epoch 100| loss: 3315.87476| val_0_unsup_loss: 41764384.0| 0:00:12s
Early stopping occurred at epoch 148 with best_epoch = 48 and best_val_0_unsup_loss = 4416323.0
Best weights from best epoch are automatically used!
Loading weights from unsupervised pretraining
epoch 0 | loss: 146.29584| valid_rmse: 12.06835| 0:00:00s
epoch 50 | loss: 2.03836 | valid_rmse: 2.66132 | 0:00:05s
epoch 100| loss: 0.06829 | valid_rmse: 0.49171 | 0:00:10s
Early stopping occurred at epoch 130 with best_epoch = 120 and best_valid_rmse = 0.37411
Best weights from best epoch are automatically used!
Run 9 RMSE: 0.2490394834684667
######### Run 10 ##############
Device used : cuda
Device used : cuda
epoch 0 | loss: 276483552.0| val_0_unsup_loss: 20401114.0| 0:00:00s
epoch 50 | loss: 379183.8125| val_0_unsup_loss: 3023311.25| 0:00:06s
epoch 100| loss: 1170.36536| val_0_unsup_loss: 2887532.0| 0:00:13s
epoch 150| loss: 0.87989 | val_0_unsup_loss: 2887532.0| 0:00:19s
Early stopping occurred at epoch 187 with best_epoch = 87 and best_val_0_unsup_loss = 2887265.25
Best weights from best epoch are automatically used!
Loading weights from unsupervised pretraining
epoch 0 | loss: 144.31578| valid_rmse: 12.07683| 0:00:00s
epoch 50 | loss: 7.20287 | valid_rmse: 4.75794 | 0:00:05s
epoch 100| loss: 0.12646 | valid_rmse: 0.51692 | 0:00:11s
epoch 150| loss: 0.05806 | valid_rmse: 0.24796 | 0:00:15s
Early stopping occurred at epoch 150 with best_epoch = 140 and best_valid_rmse = 0.23579
Best weights from best epoch are automatically used!
Run 10 RMSE: 0.4465059049988976
To check that training worked, we plot the loss curve and the validation RMSE.
plt.plot(model_best.model.history['loss'], label='loss')
plt.plot(model_best.model.history['valid_rmse'], label='valid_rmse')
plt.legend()
plt.show()
Both the loss and the validation RMSE decreased as training progressed, so the model learned successfully.
With a trained model in hand, we predict on the held-out test split and plot the ground truth on the horizontal axis against the predictions on the vertical axis to visualize the fit.
y_pred = model_best.predict(X_test)
y_pred_train = model_best.predict(X_train)
rmse = np.sqrt(mean_squared_error(y_test, y_pred))
print(rmse)
mae = mean_absolute_error(y_test, y_pred)
print(mae)
plt.plot([10, 14], [10, 14], '-', color='#00000033', markersize=10)
plt.plot(y_train, y_pred_train, 'ob', markersize = 4, label = 'train')
plt.plot(y_test, y_pred, 'or', markersize = 4, label='test')
plt.xlim(10.3, 13.8)
plt.ylim(10.3, 13.8)
plt.legend()
plt.xlabel('answer')
plt.ylabel('predict')
plt.show()
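A note on the metric: because `SalePrice` was `log1p`-transformed before training, the RMSE computed above on the log values is effectively the RMSLE on the original prices, which matches the log-scale RMSE this competition evaluates. A small check with made-up prices:

```python
import numpy as np
from sklearn.metrics import mean_squared_error, mean_squared_log_error

y_true = np.array([208500.0, 181500.0, 140000.0])
y_pred = np.array([200000.0, 190000.0, 150000.0])

# RMSE computed on the log1p-transformed values ...
rmse_log = np.sqrt(mean_squared_error(np.log1p(y_true), np.log1p(y_pred)))
# ... equals the RMSLE computed on the raw prices
rmsle = np.sqrt(mean_squared_log_error(y_true, y_pred))
print(np.isclose(rmse_log, rmsle))  # True
```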
Next, we extract and visualize the feature importances.
feat_importances = pd.Series(model_best.model.feature_importances_, index=train_df.columns)
feat_importances.nlargest(20).plot(kind='barh')
plt.show()
Features such as `GarageType_Attchd` and `Exterior2nd_VinylSd` turned out to be important.
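One caveat when reading these importances: after one-hot encoding, a single original categorical variable such as `GarageType` is split across many dummy columns, so its total contribution is spread out. With hypothetical importance values:

```python
import pandas as pd

# hypothetical importances over one-hot encoded columns (made-up numbers)
imp = pd.Series({
    'OverallQual': 0.35,
    'GrLivArea': 0.30,
    'GarageType_Attchd': 0.25,
    'GarageType_Detchd': 0.10,
})
# total contribution of the original GarageType variable
garage_total = imp.filter(like='GarageType_').sum()
print(round(garage_total, 2))  # 0.35
```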
Submission
Finally, we predict on the submission data and create the submission file.
# Create the submission file
X_submit = test_df.values
y_submit = model_best.predict(X_submit)
# Convert the predictions back to the original price scale
y_submit = np.expm1(y_submit)
# Add the predictions
submission_df = pd.read_csv('./sample_submission.csv')
submission_df['SalePrice'] = pd.DataFrame(y_submit)
# Save as a CSV file
submission_df.to_csv("submit.csv", index=False)
Submitting submit.csv to Kaggle gives you your score and ranking.
Summary
In this article we used TabNet to tackle Kaggle's house price prediction problem.
Since the goal was simply to apply TabNet to a regression problem, no hyperparameter tuning was performed; with tuning, better results should be possible.