tldr
KaggleのCOVID-19 Healthy Diet Datasetを、Predicting COVID-19 Mortality - Data Every Day #048に沿ってやっていきます。
実行環境はGoogle Colaboratoryです。
インポート
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import sklearn.preprocessing as sp
from sklearn.model_selection import train_test_split
import sklearn.linear_model as slm
import tensorflow as tf
データのダウンロード
Google Driveをマウントします。
# Mount Google Drive into the Colab filesystem at /content/drive.
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
KaggleのAPIクライアントを初期化し、認証します。
認証情報は、Google Drive内（/content/drive/My Drive/Colab Notebooks/Kaggle）にkaggle.jsonとして置いてあります。
import os
# Directory on the mounted Drive that holds kaggle.json.
kaggle_path = "/content/drive/My Drive/Colab Notebooks/Kaggle"
# Point the Kaggle client at that directory. This is set before the kaggle
# import below (presumably the client reads the variable when it is loaded —
# verify before reordering these lines).
os.environ['KAGGLE_CONFIG_DIR'] = kaggle_path
from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate()
Kaggle APIを使ってデータをダウンロードします。
# Kaggle dataset identifier passed to the API calls below.
dataset_id = 'mariaren/covid19-healthy-diet-dataset'
dataset = api.dataset_list_files(dataset_id)
# Take the first listed file.
# NOTE(review): this relies on the listing order — confirm it is the intended CSV.
file_name = dataset.files[0].name
# Path the file is expected to have once downloaded to the default directory.
file_path = os.path.join(api.get_default_download_dir(), file_name)
file_path
'/content/Food_Supply_Quantity_kg_Data.csv'
# Download only the selected file from the dataset.
api.dataset_download_file(dataset_id, file_name, force=True, quiet=False)
100%|██████████| 43.4k/43.4k [00:00<00:00, 22.3MB/s]
Downloading Food_Supply_Quantity_kg_Data.csv to /content
True
データの読み込み
# Load the downloaded CSV into a DataFrame and display it.
data = pd.read_csv(file_path)
data
Country | Alcoholic Beverages | Animal fats | Animal Products | Aquatic Products, Other | Cereals - Excluding Beer | Eggs | Fish, Seafood | Fruits - Excluding Wine | Meat | Milk - Excluding Butter | Miscellaneous | Offals | Oilcrops | Pulses | Spices | Starchy Roots | Stimulants | Sugar & Sweeteners | Sugar Crops | Treenuts | Vegetable Oils | Vegetables | Vegetal Products | Obesity | Undernourished | Confirmed | Deaths | Recovered | Active | Population | Unit (all except Population) | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Afghanistan | 0.0014 | 0.1973 | 9.4341 | 0.0000 | 24.8097 | 0.2099 | 0.0350 | 5.3495 | 1.2020 | 7.5828 | 0.0728 | 0.2057 | 0.0700 | 0.2953 | 0.0574 | 0.8802 | 0.3078 | 1.3489 | 0.000 | 0.0770 | 0.5345 | 6.7642 | 40.5645 | 4.5 | 29.8 | 0.125149 | 0.005058 | 0.098263 | 0.021827 | 38928000.0 | % |
1 | Albania | 1.6719 | 0.1357 | 18.7684 | 0.0000 | 5.7817 | 0.5815 | 0.2126 | 6.7861 | 1.8845 | 15.7213 | 0.1123 | 0.2324 | 0.9377 | 0.2380 | 0.0008 | 1.8096 | 0.1055 | 1.5367 | 0.000 | 0.1515 | 0.3261 | 11.7753 | 31.2304 | 22.3 | 6.2 | 1.733298 | 0.035800 | 0.874560 | 0.822939 | 2838000.0 | % |
2 | Algeria | 0.2711 | 0.0282 | 9.6334 | 0.0000 | 13.6816 | 0.5277 | 0.2416 | 6.3801 | 1.1305 | 7.6189 | 0.1671 | 0.0870 | 0.3493 | 0.4783 | 0.0557 | 4.1340 | 0.2216 | 1.8342 | 0.000 | 0.1152 | 1.0310 | 11.6484 | 40.3651 | 26.6 | 3.9 | 0.208754 | 0.005882 | 0.137268 | 0.065604 | 44357000.0 | % |
3 | Angola | 5.8087 | 0.0560 | 4.9278 | 0.0000 | 9.1085 | 0.0587 | 1.7707 | 6.0005 | 2.0571 | 0.8311 | 0.1165 | 0.1550 | 0.4186 | 0.6507 | 0.0009 | 18.1102 | 0.0508 | 1.8495 | 0.000 | 0.0061 | 0.6463 | 2.3041 | 45.0722 | 6.8 | 25 | 0.050049 | 0.001144 | 0.027440 | 0.021465 | 32522000.0 | % |
4 | Antigua and Barbuda | 3.5764 | 0.0087 | 16.6613 | 0.0000 | 5.9960 | 0.2274 | 4.1489 | 10.7451 | 5.6888 | 6.3663 | 0.7139 | 0.2219 | 0.2172 | 0.1840 | 0.1524 | 1.4522 | 0.1564 | 3.8749 | 0.000 | 0.0253 | 0.8102 | 5.4495 | 33.3233 | 19.1 | NaN | 0.151020 | 0.005102 | 0.140816 | 0.005102 | 98000.0 | % |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
165 | Venezuela (Bolivarian Republic of) | 2.5952 | 0.0403 | 14.7565 | 0.0000 | 12.9253 | 0.3389 | 0.9456 | 7.6460 | 3.8328 | 9.3920 | 0.0702 | 0.2078 | 0.0281 | 0.3342 | 0.0009 | 2.5643 | 0.1479 | 3.4106 | 0.000 | 0.0009 | 1.3734 | 4.1474 | 35.2416 | 25.2 | 21.2 | 0.377466 | 0.003351 | 0.359703 | 0.014411 | 28645000.0 | % |
166 | Vietnam | 1.4591 | 0.1640 | 8.5765 | 0.0042 | 16.8740 | 0.3077 | 2.6392 | 5.9029 | 4.4382 | 0.6069 | 0.0126 | 0.4149 | 0.8410 | 0.2032 | 0.2074 | 1.0596 | 0.2880 | 1.2846 | 0.815 | 0.3070 | 0.2201 | 11.9508 | 41.4232 | 2.1 | 9.3 | 0.001457 | 0.000036 | 0.001295 | 0.000126 | 96209000.0 | % |
167 | Yemen | 0.0364 | 0.0446 | 5.7874 | 0.0000 | 27.2077 | 0.2579 | 0.5240 | 5.1344 | 2.7871 | 1.8911 | 0.2033 | 0.2827 | 0.0893 | 0.8645 | 0.0347 | 1.0794 | 0.2199 | 5.0468 | 0.000 | 0.0017 | 1.0811 | 3.2135 | 44.2126 | 14.1 | 38.9 | 0.006987 | 0.002032 | 0.004640 | 0.000315 | 29826000.0 | % |
168 | Zambia | 5.7360 | 0.0829 | 6.0197 | 0.0000 | 21.1938 | 0.3399 | 1.6924 | 1.0183 | 1.8427 | 1.7570 | 0.2149 | 0.3048 | 1.8736 | 0.1756 | 0.0478 | 7.9649 | 0.0618 | 1.5632 | 0.000 | 0.0014 | 0.6657 | 3.4649 | 43.9789 | 6.5 | 46.7 | 0.099663 | 0.001996 | 0.094696 | 0.002970 | 18384000.0 | % |
169 | Zimbabwe | 4.0552 | 0.0755 | 8.1489 | 0.0000 | 22.6240 | 0.2678 | 0.5518 | 2.2000 | 2.6142 | 4.4310 | 0.2012 | 0.2086 | 0.4498 | 0.4261 | 0.0252 | 2.9870 | 0.1494 | 4.6485 | 0.000 | 0.0518 | 1.7103 | 2.3213 | 41.8526 | 12.3 | 51.3 | 0.076418 | 0.002079 | 0.064280 | 0.010059 | 14863000.0 | % |
170 rows × 32 columns
下準備
# Show each column's dtype and non-null count.
data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 170 entries, 0 to 169
Data columns (total 32 columns):
# Column Non-Null Count Dtype
--- ------ -------------- -----
0 Country 170 non-null object
1 Alcoholic Beverages 170 non-null float64
2 Animal fats 170 non-null float64
3 Animal Products 170 non-null float64
4 Aquatic Products, Other 170 non-null float64
5 Cereals - Excluding Beer 170 non-null float64
6 Eggs 170 non-null float64
7 Fish, Seafood 170 non-null float64
8 Fruits - Excluding Wine 170 non-null float64
9 Meat 170 non-null float64
10 Milk - Excluding Butter 170 non-null float64
11 Miscellaneous 170 non-null float64
12 Offals 170 non-null float64
13 Oilcrops 170 non-null float64
14 Pulses 170 non-null float64
15 Spices 170 non-null float64
16 Starchy Roots 170 non-null float64
17 Stimulants 170 non-null float64
18 Sugar & Sweeteners 170 non-null float64
19 Sugar Crops 170 non-null float64
20 Treenuts 170 non-null float64
21 Vegetable Oils 170 non-null float64
22 Vegetables 170 non-null float64
23 Vegetal Products 170 non-null float64
24 Obesity 167 non-null float64
25 Undernourished 163 non-null object
26 Confirmed 164 non-null float64
27 Deaths 164 non-null float64
28 Recovered 164 non-null float64
29 Active 162 non-null float64
30 Population 170 non-null float64
31 Unit (all except Population) 170 non-null object
dtypes: float64(29), object(3)
memory usage: 42.6+ KB
# Count missing values per column.
data.isna().sum()
Country 0
Alcoholic Beverages 0
Animal fats 0
Animal Products 0
Aquatic Products, Other 0
Cereals - Excluding Beer 0
Eggs 0
Fish, Seafood 0
Fruits - Excluding Wine 0
Meat 0
Milk - Excluding Butter 0
Miscellaneous 0
Offals 0
Oilcrops 0
Pulses 0
Spices 0
Starchy Roots 0
Stimulants 0
Sugar & Sweeteners 0
Sugar Crops 0
Treenuts 0
Vegetable Oils 0
Vegetables 0
Vegetal Products 0
Obesity 3
Undernourished 7
Confirmed 6
Deaths 6
Recovered 6
Active 8
Population 0
Unit (all except Population) 0
dtype: int64
# Drop the unit label column; it is text ('%' in the rows shown), not a feature.
data = data.drop('Unit (all except Population)', axis=1)

# Mean-impute every non-object column. Object columns are skipped, which
# leaves 'Undernourished' (object dtype) for separate handling.
for column in data.columns:
    if data.dtypes[column] != 'object':
        data[column] = data[column].fillna(data[column].mean())

# Re-check the remaining missing values per column.
data.isna().sum()
Country 0
Alcoholic Beverages 0
Animal fats 0
Animal Products 0
Aquatic Products, Other 0
Cereals - Excluding Beer 0
Eggs 0
Fish, Seafood 0
Fruits - Excluding Wine 0
Meat 0
Milk - Excluding Butter 0
Miscellaneous 0
Offals 0
Oilcrops 0
Pulses 0
Spices 0
Starchy Roots 0
Stimulants 0
Sugar & Sweeteners 0
Sugar Crops 0
Treenuts 0
Vegetable Oils 0
Vegetables 0
Vegetal Products 0
Obesity 0
Undernourished 7
Confirmed 0
Deaths 0
Recovered 0
Active 0
Population 0
dtype: int64
# Select the rows whose value is not the censored marker '<2.5' and cast them
# to floats. NaN rows are selected too, because NaN != '<2.5' is True.
# The builtin float is used because the np.float alias was deprecated and
# later removed from NumPy.
undernourished_numeric = data.loc[data['Undernourished'] != '<2.5', 'Undernourished'].astype(float)
undernourished_numeric
0 29.8
1 6.2
2 3.9
3 25.0
4 NaN
...
165 21.2
166 9.3
167 38.9
168 46.7
169 51.3
Name: Undernourished, Length: 126, dtype: float64
# Fill the remaining gaps with the mean of the numeric values, then cut the
# result into three quantile-based bins labelled 1, 2 and 3.
undernourished_mean = undernourished_numeric.mean()
undernourished_filled = undernourished_numeric.fillna(undernourished_mean)
undernourished_numeric = pd.qcut(undernourished_filled, q=3, labels=[1, 2, 3])
undernourished_numeric
0 3
1 1
2 1
3 3
4 2
..
165 3
166 2
167 3
168 3
169 3
Name: Undernourished, Length: 126, dtype: category
Categories (3, int64): [1 < 2 < 3]
def _censored_to_zero(value):
    """Return 0 for the censored marker '<2.5'; pass any other value through."""
    return 0 if value == '<2.5' else value


# Write the bin labels back onto the rows they were computed from.
data.loc[undernourished_numeric.index, 'Undernourished'] = undernourished_numeric
# The rows not touched above still hold '<2.5'; encode them as level 0.
data['Undernourished'] = data['Undernourished'].map(_censored_to_zero)
data['Undernourished'].value_counts()
0 44
3 42
2 42
1 42
Name: Undernourished, dtype: int64
# Display the frame after encoding 'Undernourished'.
data
Country | Alcoholic Beverages | Animal fats | Animal Products | Aquatic Products, Other | Cereals - Excluding Beer | Eggs | Fish, Seafood | Fruits - Excluding Wine | Meat | Milk - Excluding Butter | Miscellaneous | Offals | Oilcrops | Pulses | Spices | Starchy Roots | Stimulants | Sugar & Sweeteners | Sugar Crops | Treenuts | Vegetable Oils | Vegetables | Vegetal Products | Obesity | Undernourished | Confirmed | Deaths | Recovered | Active | Population | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | Afghanistan | 0.0014 | 0.1973 | 9.4341 | 0.0000 | 24.8097 | 0.2099 | 0.0350 | 5.3495 | 1.2020 | 7.5828 | 0.0728 | 0.2057 | 0.0700 | 0.2953 | 0.0574 | 0.8802 | 0.3078 | 1.3489 | 0.000 | 0.0770 | 0.5345 | 6.7642 | 40.5645 | 4.5 | 3 | 0.125149 | 0.005058 | 0.098263 | 0.021827 | 38928000.0 |
1 | Albania | 1.6719 | 0.1357 | 18.7684 | 0.0000 | 5.7817 | 0.5815 | 0.2126 | 6.7861 | 1.8845 | 15.7213 | 0.1123 | 0.2324 | 0.9377 | 0.2380 | 0.0008 | 1.8096 | 0.1055 | 1.5367 | 0.000 | 0.1515 | 0.3261 | 11.7753 | 31.2304 | 22.3 | 1 | 1.733298 | 0.035800 | 0.874560 | 0.822939 | 2838000.0 |
2 | Algeria | 0.2711 | 0.0282 | 9.6334 | 0.0000 | 13.6816 | 0.5277 | 0.2416 | 6.3801 | 1.1305 | 7.6189 | 0.1671 | 0.0870 | 0.3493 | 0.4783 | 0.0557 | 4.1340 | 0.2216 | 1.8342 | 0.000 | 0.1152 | 1.0310 | 11.6484 | 40.3651 | 26.6 | 1 | 0.208754 | 0.005882 | 0.137268 | 0.065604 | 44357000.0 |
3 | Angola | 5.8087 | 0.0560 | 4.9278 | 0.0000 | 9.1085 | 0.0587 | 1.7707 | 6.0005 | 2.0571 | 0.8311 | 0.1165 | 0.1550 | 0.4186 | 0.6507 | 0.0009 | 18.1102 | 0.0508 | 1.8495 | 0.000 | 0.0061 | 0.6463 | 2.3041 | 45.0722 | 6.8 | 3 | 0.050049 | 0.001144 | 0.027440 | 0.021465 | 32522000.0 |
4 | Antigua and Barbuda | 3.5764 | 0.0087 | 16.6613 | 0.0000 | 5.9960 | 0.2274 | 4.1489 | 10.7451 | 5.6888 | 6.3663 | 0.7139 | 0.2219 | 0.2172 | 0.1840 | 0.1524 | 1.4522 | 0.1564 | 3.8749 | 0.000 | 0.0253 | 0.8102 | 5.4495 | 33.3233 | 19.1 | 2 | 0.151020 | 0.005102 | 0.140816 | 0.005102 | 98000.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
165 | Venezuela (Bolivarian Republic of) | 2.5952 | 0.0403 | 14.7565 | 0.0000 | 12.9253 | 0.3389 | 0.9456 | 7.6460 | 3.8328 | 9.3920 | 0.0702 | 0.2078 | 0.0281 | 0.3342 | 0.0009 | 2.5643 | 0.1479 | 3.4106 | 0.000 | 0.0009 | 1.3734 | 4.1474 | 35.2416 | 25.2 | 3 | 0.377466 | 0.003351 | 0.359703 | 0.014411 | 28645000.0 |
166 | Vietnam | 1.4591 | 0.1640 | 8.5765 | 0.0042 | 16.8740 | 0.3077 | 2.6392 | 5.9029 | 4.4382 | 0.6069 | 0.0126 | 0.4149 | 0.8410 | 0.2032 | 0.2074 | 1.0596 | 0.2880 | 1.2846 | 0.815 | 0.3070 | 0.2201 | 11.9508 | 41.4232 | 2.1 | 2 | 0.001457 | 0.000036 | 0.001295 | 0.000126 | 96209000.0 |
167 | Yemen | 0.0364 | 0.0446 | 5.7874 | 0.0000 | 27.2077 | 0.2579 | 0.5240 | 5.1344 | 2.7871 | 1.8911 | 0.2033 | 0.2827 | 0.0893 | 0.8645 | 0.0347 | 1.0794 | 0.2199 | 5.0468 | 0.000 | 0.0017 | 1.0811 | 3.2135 | 44.2126 | 14.1 | 3 | 0.006987 | 0.002032 | 0.004640 | 0.000315 | 29826000.0 |
168 | Zambia | 5.7360 | 0.0829 | 6.0197 | 0.0000 | 21.1938 | 0.3399 | 1.6924 | 1.0183 | 1.8427 | 1.7570 | 0.2149 | 0.3048 | 1.8736 | 0.1756 | 0.0478 | 7.9649 | 0.0618 | 1.5632 | 0.000 | 0.0014 | 0.6657 | 3.4649 | 43.9789 | 6.5 | 3 | 0.099663 | 0.001996 | 0.094696 | 0.002970 | 18384000.0 |
169 | Zimbabwe | 4.0552 | 0.0755 | 8.1489 | 0.0000 | 22.6240 | 0.2678 | 0.5518 | 2.2000 | 2.6142 | 4.4310 | 0.2012 | 0.2086 | 0.4498 | 0.4261 | 0.0252 | 2.9870 | 0.1494 | 4.6485 | 0.000 | 0.0518 | 1.7103 | 2.3213 | 41.8526 | 12.3 | 3 | 0.076418 | 0.002079 | 0.064280 | 0.010059 | 14863000.0 |
170 rows × 31 columns
フィーチャーとターゲットの選択
# Remove the country name and the other case-count columns in one pass;
# 'Deaths' stays because the target is built from it.
excluded_columns = ['Country', 'Confirmed', 'Recovered', 'Active']
data = data.drop(columns=excluded_columns)

# Preview a two-quantile split of 'Deaths' and the size of each half.
death_halves = pd.qcut(data['Deaths'], q=2, labels=[0, 1])
death_halves.value_counts()
1 85
0 85
Name: Deaths, dtype: int64
# Binary target: 'Deaths' cut into two quantile bins, labelled 0 (lower) and 1 (upper).
data['Death'] = pd.qcut(data['Deaths'], q=2, labels=[0, 1])
分割とスケーリング
# Target: the binary label derived from 'Deaths'.
y = data['Death']

# 'Death' is computed directly from 'Deaths', so both columns are removed from
# the features. Leaving 'Deaths' in X would leak the target into the inputs.
# NOTE(review): the outputs recorded further down this page were produced with
# 'Deaths' still among the features; re-run the cells to refresh them.
X = data.drop(['Death', 'Deaths'], axis=1)

# NOTE(review): the scaler is fitted on all rows before the split, so test-set
# statistics influence the scaling; consider fitting on the training rows only.
scaler = sp.StandardScaler()
X = scaler.fit_transform(X)

X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8)
トレーニング
X.shape
(170, 26)
model = tf.keras.Sequential([
    # Input width comes from the data, so it follows the feature count.
    tf.keras.layers.Dense(64, activation='relu', input_shape=(X.shape[1],)),
    tf.keras.layers.Dense(64, activation='relu'),
    # Single sigmoid unit for the binary target.
    tf.keras.layers.Dense(1, activation='sigmoid'),
])
model.summary()

# Report accuracy and AUC alongside the binary cross-entropy loss.
tracked_metrics = ['accuracy', tf.keras.metrics.AUC(name='auc')]
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=tracked_metrics)

batch_size = 64
epochs = 100

# Train without progress output, reserving 20% of the training rows for validation.
fit_options = dict(
    validation_split=0.2,
    batch_size=batch_size,
    epochs=epochs,
    verbose=0,
)
history = model.fit(X_train, y_train, **fit_options)
Model: "sequential_1"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
dense_3 (Dense) (None, 64) 1792
_________________________________________________________________
dense_4 (Dense) (None, 64) 4160
_________________________________________________________________
dense_5 (Dense) (None, 1) 65
=================================================================
Total params: 6,017
Trainable params: 6,017
Non-trainable params: 0
_________________________________________________________________
結果
# Per-epoch loss curves for the training and validation sets.
epochs_range = range(1, epochs + 1)
train_loss = history.history['loss']
val_loss = history.history['val_loss']

fig, ax = plt.subplots(figsize=(14, 10))
for losses, curve_label in ((train_loss, 'Training Loss'), (val_loss, 'Validation Loss')):
    ax.plot(epochs_range, losses, label=curve_label)
ax.set_title('Training and Validation Loss')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.legend()
plt.show()
# Zero-based position of the lowest validation loss; add 1 to get the epoch
# number used on the plot's x-axis, which starts at 1.
np.argmin(val_loss)
28
# Score the trained model on the held-out test split: loss, accuracy, AUC.
model.evaluate(X_test, y_test)
2/2 [==============================] - 0s 5ms/step - loss: 0.3192 - accuracy: 0.9118 - auc: 0.9386
[0.3191758096218109, 0.9117646813392639, 0.9385964870452881]