
More than 3 years have passed since last update.

Data Every Day: COVID-19 Healthy Diet Dataset


tldr

This post works through Kaggle's COVID-19 Healthy Diet Dataset step by step, following Predicting COVID-19 Mortality - Data Every Day #048.

The runtime environment is Google Colaboratory.

Imports

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import sklearn.preprocessing as sp
from sklearn.model_selection import train_test_split
import sklearn.linear_model as slm

import tensorflow as tf

Downloading the data

Mount Google Drive.

from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive

Initialize the Kaggle API client and authenticate.
The credentials are stored in Google Drive as kaggle.json, in /content/drive/My Drive/Colab Notebooks/Kaggle.

import os
kaggle_path = "/content/drive/My Drive/Colab Notebooks/Kaggle"
os.environ['KAGGLE_CONFIG_DIR'] = kaggle_path

from kaggle.api.kaggle_api_extended import KaggleApi
api = KaggleApi()
api.authenticate() 
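The Kaggle client looks for kaggle.json in the directory named by KAGGLE_CONFIG_DIR, and that file holds a "username" and a "key". If authenticate() fails, a quick check of the file can save some time. The following is a minimal sketch. The helper check_kaggle_config and the placeholder credentials are hypothetical.

```python
import json
import os
import tempfile

def check_kaggle_config(config_dir: str) -> bool:
    """Return True if config_dir holds a kaggle.json with 'username' and 'key' fields."""
    path = os.path.join(config_dir, 'kaggle.json')
    if not os.path.isfile(path):
        return False
    with open(path) as f:
        creds = json.load(f)
    return isinstance(creds, dict) and {'username', 'key'} <= set(creds)

# Usage with a throwaway directory standing in for the Drive folder (hypothetical values)
with tempfile.TemporaryDirectory() as tmp:
    before = check_kaggle_config(tmp)  # no kaggle.json yet
    with open(os.path.join(tmp, 'kaggle.json'), 'w') as f:
        json.dump({'username': 'your-username', 'key': 'your-api-key'}, f)
    after = check_kaggle_config(tmp)

print(before, after)  # False True
```

In the notebook, you would call check_kaggle_config(kaggle_path) before api.authenticate().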

Download the data with the Kaggle API.

dataset_id = 'mariaren/covid19-healthy-diet-dataset'
dataset = api.dataset_list_files(dataset_id)
file_name = dataset.files[0].name
file_path = os.path.join(api.get_default_download_dir(), file_name)
file_path
'/content/Food_Supply_Quantity_kg_Data.csv'
api.dataset_download_file(dataset_id, file_name, force=True, quiet=False)
100%|██████████| 43.4k/43.4k [00:00<00:00, 22.3MB/s]
Downloading Food_Supply_Quantity_kg_Data.csv to /content
True

Loading the data

data = pd.read_csv(file_path)
data
Country Alcoholic Beverages Animal fats Animal Products Aquatic Products, Other Cereals - Excluding Beer Eggs Fish, Seafood Fruits - Excluding Wine Meat Milk - Excluding Butter Miscellaneous Offals Oilcrops Pulses Spices Starchy Roots Stimulants Sugar & Sweeteners Sugar Crops Treenuts Vegetable Oils Vegetables Vegetal Products Obesity Undernourished Confirmed Deaths Recovered Active Population Unit (all except Population)
0 Afghanistan 0.0014 0.1973 9.4341 0.0000 24.8097 0.2099 0.0350 5.3495 1.2020 7.5828 0.0728 0.2057 0.0700 0.2953 0.0574 0.8802 0.3078 1.3489 0.000 0.0770 0.5345 6.7642 40.5645 4.5 29.8 0.125149 0.005058 0.098263 0.021827 38928000.0 %
1 Albania 1.6719 0.1357 18.7684 0.0000 5.7817 0.5815 0.2126 6.7861 1.8845 15.7213 0.1123 0.2324 0.9377 0.2380 0.0008 1.8096 0.1055 1.5367 0.000 0.1515 0.3261 11.7753 31.2304 22.3 6.2 1.733298 0.035800 0.874560 0.822939 2838000.0 %
2 Algeria 0.2711 0.0282 9.6334 0.0000 13.6816 0.5277 0.2416 6.3801 1.1305 7.6189 0.1671 0.0870 0.3493 0.4783 0.0557 4.1340 0.2216 1.8342 0.000 0.1152 1.0310 11.6484 40.3651 26.6 3.9 0.208754 0.005882 0.137268 0.065604 44357000.0 %
3 Angola 5.8087 0.0560 4.9278 0.0000 9.1085 0.0587 1.7707 6.0005 2.0571 0.8311 0.1165 0.1550 0.4186 0.6507 0.0009 18.1102 0.0508 1.8495 0.000 0.0061 0.6463 2.3041 45.0722 6.8 25 0.050049 0.001144 0.027440 0.021465 32522000.0 %
4 Antigua and Barbuda 3.5764 0.0087 16.6613 0.0000 5.9960 0.2274 4.1489 10.7451 5.6888 6.3663 0.7139 0.2219 0.2172 0.1840 0.1524 1.4522 0.1564 3.8749 0.000 0.0253 0.8102 5.4495 33.3233 19.1 NaN 0.151020 0.005102 0.140816 0.005102 98000.0 %
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
165 Venezuela (Bolivarian Republic of) 2.5952 0.0403 14.7565 0.0000 12.9253 0.3389 0.9456 7.6460 3.8328 9.3920 0.0702 0.2078 0.0281 0.3342 0.0009 2.5643 0.1479 3.4106 0.000 0.0009 1.3734 4.1474 35.2416 25.2 21.2 0.377466 0.003351 0.359703 0.014411 28645000.0 %
166 Vietnam 1.4591 0.1640 8.5765 0.0042 16.8740 0.3077 2.6392 5.9029 4.4382 0.6069 0.0126 0.4149 0.8410 0.2032 0.2074 1.0596 0.2880 1.2846 0.815 0.3070 0.2201 11.9508 41.4232 2.1 9.3 0.001457 0.000036 0.001295 0.000126 96209000.0 %
167 Yemen 0.0364 0.0446 5.7874 0.0000 27.2077 0.2579 0.5240 5.1344 2.7871 1.8911 0.2033 0.2827 0.0893 0.8645 0.0347 1.0794 0.2199 5.0468 0.000 0.0017 1.0811 3.2135 44.2126 14.1 38.9 0.006987 0.002032 0.004640 0.000315 29826000.0 %
168 Zambia 5.7360 0.0829 6.0197 0.0000 21.1938 0.3399 1.6924 1.0183 1.8427 1.7570 0.2149 0.3048 1.8736 0.1756 0.0478 7.9649 0.0618 1.5632 0.000 0.0014 0.6657 3.4649 43.9789 6.5 46.7 0.099663 0.001996 0.094696 0.002970 18384000.0 %
169 Zimbabwe 4.0552 0.0755 8.1489 0.0000 22.6240 0.2678 0.5518 2.2000 2.6142 4.4310 0.2012 0.2086 0.4498 0.4261 0.0252 2.9870 0.1494 4.6485 0.000 0.0518 1.7103 2.3213 41.8526 12.3 51.3 0.076418 0.002079 0.064280 0.010059 14863000.0 %

170 rows × 32 columns

Preparation

data.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 170 entries, 0 to 169
Data columns (total 32 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   Country                       170 non-null    object 
 1   Alcoholic Beverages           170 non-null    float64
 2   Animal fats                   170 non-null    float64
 3   Animal Products               170 non-null    float64
 4   Aquatic Products, Other       170 non-null    float64
 5   Cereals - Excluding Beer      170 non-null    float64
 6   Eggs                          170 non-null    float64
 7   Fish, Seafood                 170 non-null    float64
 8   Fruits - Excluding Wine       170 non-null    float64
 9   Meat                          170 non-null    float64
 10  Milk - Excluding Butter       170 non-null    float64
 11  Miscellaneous                 170 non-null    float64
 12  Offals                        170 non-null    float64
 13  Oilcrops                      170 non-null    float64
 14  Pulses                        170 non-null    float64
 15  Spices                        170 non-null    float64
 16  Starchy Roots                 170 non-null    float64
 17  Stimulants                    170 non-null    float64
 18  Sugar & Sweeteners            170 non-null    float64
 19  Sugar Crops                   170 non-null    float64
 20  Treenuts                      170 non-null    float64
 21  Vegetable Oils                170 non-null    float64
 22  Vegetables                    170 non-null    float64
 23  Vegetal Products              170 non-null    float64
 24  Obesity                       167 non-null    float64
 25  Undernourished                163 non-null    object 
 26  Confirmed                     164 non-null    float64
 27  Deaths                        164 non-null    float64
 28  Recovered                     164 non-null    float64
 29  Active                        162 non-null    float64
 30  Population                    170 non-null    float64
 31  Unit (all except Population)  170 non-null    object 
dtypes: float64(29), object(3)
memory usage: 42.6+ KB
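Undernourished is listed as object rather than float64. That suggests some entries are not numbers. One way to find those entries is to parse the column with pd.to_numeric(errors='coerce') and list the values that fail to parse. The following is a minimal sketch. The helper non_numeric_values and the sample values are hypothetical.

```python
import pandas as pd

def non_numeric_values(s: pd.Series) -> list:
    """Return the distinct non-missing entries of s that cannot be parsed as numbers."""
    parsed = pd.to_numeric(s, errors='coerce')  # unparseable entries become NaN
    mask = parsed.isna() & s.notna()            # NaN after parsing but not originally missing
    return sorted(s[mask].unique().tolist())

sample = pd.Series(['29.8', '6.2', '<2.5', None, '<2.5', '3.9'])
found = non_numeric_values(sample)
print(found)  # ['<2.5']
```

Running non_numeric_values(data['Undernourished']) on the real table should reveal the '<2.5' placeholder handled further down.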
data.isna().sum()
Country                         0
Alcoholic Beverages             0
Animal fats                     0
Animal Products                 0
Aquatic Products, Other         0
Cereals - Excluding Beer        0
Eggs                            0
Fish, Seafood                   0
Fruits - Excluding Wine         0
Meat                            0
Milk - Excluding Butter         0
Miscellaneous                   0
Offals                          0
Oilcrops                        0
Pulses                          0
Spices                          0
Starchy Roots                   0
Stimulants                      0
Sugar & Sweeteners              0
Sugar Crops                     0
Treenuts                        0
Vegetable Oils                  0
Vegetables                      0
Vegetal Products                0
Obesity                         3
Undernourished                  7
Confirmed                       6
Deaths                          6
Recovered                       6
Active                          8
Population                      0
Unit (all except Population)    0
dtype: int64
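# Drop the unit-label column; it is metadata, not a feature
data = data.drop('Unit (all except Population)', axis=1)
# Fill missing values in each numeric column with that column's mean
for column in data.columns:
    if data.dtypes[column] != 'object':
        data[column] = data[column].fillna(data[column].mean())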
data.isna().sum()
Country                     0
Alcoholic Beverages         0
Animal fats                 0
Animal Products             0
Aquatic Products, Other     0
Cereals - Excluding Beer    0
Eggs                        0
Fish, Seafood               0
Fruits - Excluding Wine     0
Meat                        0
Milk - Excluding Butter     0
Miscellaneous               0
Offals                      0
Oilcrops                    0
Pulses                      0
Spices                      0
Starchy Roots               0
Stimulants                  0
Sugar & Sweeteners          0
Sugar Crops                 0
Treenuts                    0
Vegetable Oils              0
Vegetables                  0
Vegetal Products            0
Obesity                     0
Undernourished              7
Confirmed                   0
Deaths                      0
Recovered                   0
Active                      0
Population                  0
dtype: int64
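Undernourished still has 7 missing values because it is an object column, so the loop skips it. The same mean fill can be written without a loop by selecting the numeric columns with select_dtypes. The following is a minimal sketch on a toy frame. The helper name and the values are hypothetical.

```python
import numpy as np
import pandas as pd

def fill_numeric_with_mean(df: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of df with NaNs in numeric columns replaced by each column's mean."""
    out = df.copy()
    numeric_cols = out.select_dtypes(include='number').columns
    # fillna with a Series fills each column using the value under its label
    out[numeric_cols] = out[numeric_cols].fillna(out[numeric_cols].mean())
    return out

toy = pd.DataFrame({
    'Country': ['A', 'B', 'C'],
    'Obesity': [10.0, np.nan, 20.0],
    'Undernourished': ['5.0', None, '<2.5'],  # object column: left untouched
})
filled = fill_numeric_with_mean(toy)
print(filled['Obesity'].tolist())  # [10.0, 15.0, 20.0]
```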
# Keep the rows whose value is numeric ('<2.5' is a text placeholder) and convert them to float
undernourished_numeric = data.loc[data['Undernourished'] != '<2.5', 'Undernourished'].astype(float)
undernourished_numeric
0      29.8
1       6.2
2       3.9
3      25.0
4       NaN
       ... 
165    21.2
166     9.3
167    38.9
168    46.7
169    51.3
Name: Undernourished, Length: 126, dtype: float64
undernourished_numeric = undernourished_numeric.fillna(undernourished_numeric.mean())
undernourished_numeric = pd.qcut(undernourished_numeric, q=3, labels=[1, 2, 3])
undernourished_numeric
0      3
1      1
2      1
3      3
4      2
      ..
165    3
166    2
167    3
168    3
169    3
Name: Undernourished, Length: 126, dtype: category
Categories (3, int64): [1 < 2 < 3]
data.loc[undernourished_numeric.index, 'Undernourished'] = undernourished_numeric
data['Undernourished'] = data['Undernourished'].apply(lambda x: 0 if x == '<2.5' else x)
data['Undernourished'].value_counts()
0    44
3    42
2    42
1    42
Name: Undernourished, dtype: int64
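The encoding above has three steps. First, '<2.5' becomes 0. Second, the remaining missing values are filled with the mean. Third, the numeric values are split into tertiles labeled 1 to 3 with pd.qcut. These steps can be combined into one function. The following is a minimal sketch under the same rules. The function name encode_undernourished and the sample values are hypothetical.

```python
import numpy as np
import pandas as pd

def encode_undernourished(s: pd.Series) -> pd.Series:
    """Map '<2.5' to 0 and split the other values into tertiles labeled 1, 2, 3."""
    below = s == '<2.5'
    numeric = s[~below].astype(float)         # NaN entries stay NaN here
    numeric = numeric.fillna(numeric.mean())  # same mean fill as in the notebook
    tertiles = pd.qcut(numeric, q=3, labels=[1, 2, 3]).astype(int)
    out = pd.Series(0, index=s.index, dtype=int)
    out.loc[tertiles.index] = tertiles.to_numpy()
    return out

sample = pd.Series(['<2.5', '3.0', '10.0', '20.0', np.nan, '<2.5', '5.0'])
encoded = encode_undernourished(sample)
print(encoded.tolist())  # [0, 1, 3, 3, 2, 0, 1]
```

The NaN entry is filled with the mean 9.5, which falls in the middle tertile.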
data
Country Alcoholic Beverages Animal fats Animal Products Aquatic Products, Other Cereals - Excluding Beer Eggs Fish, Seafood Fruits - Excluding Wine Meat Milk - Excluding Butter Miscellaneous Offals Oilcrops Pulses Spices Starchy Roots Stimulants Sugar & Sweeteners Sugar Crops Treenuts Vegetable Oils Vegetables Vegetal Products Obesity Undernourished Confirmed Deaths Recovered Active Population
0 Afghanistan 0.0014 0.1973 9.4341 0.0000 24.8097 0.2099 0.0350 5.3495 1.2020 7.5828 0.0728 0.2057 0.0700 0.2953 0.0574 0.8802 0.3078 1.3489 0.000 0.0770 0.5345 6.7642 40.5645 4.5 3 0.125149 0.005058 0.098263 0.021827 38928000.0
1 Albania 1.6719 0.1357 18.7684 0.0000 5.7817 0.5815 0.2126 6.7861 1.8845 15.7213 0.1123 0.2324 0.9377 0.2380 0.0008 1.8096 0.1055 1.5367 0.000 0.1515 0.3261 11.7753 31.2304 22.3 1 1.733298 0.035800 0.874560 0.822939 2838000.0
2 Algeria 0.2711 0.0282 9.6334 0.0000 13.6816 0.5277 0.2416 6.3801 1.1305 7.6189 0.1671 0.0870 0.3493 0.4783 0.0557 4.1340 0.2216 1.8342 0.000 0.1152 1.0310 11.6484 40.3651 26.6 1 0.208754 0.005882 0.137268 0.065604 44357000.0
3 Angola 5.8087 0.0560 4.9278 0.0000 9.1085 0.0587 1.7707 6.0005 2.0571 0.8311 0.1165 0.1550 0.4186 0.6507 0.0009 18.1102 0.0508 1.8495 0.000 0.0061 0.6463 2.3041 45.0722 6.8 3 0.050049 0.001144 0.027440 0.021465 32522000.0
4 Antigua and Barbuda 3.5764 0.0087 16.6613 0.0000 5.9960 0.2274 4.1489 10.7451 5.6888 6.3663 0.7139 0.2219 0.2172 0.1840 0.1524 1.4522 0.1564 3.8749 0.000 0.0253 0.8102 5.4495 33.3233 19.1 2 0.151020 0.005102 0.140816 0.005102 98000.0
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
165 Venezuela (Bolivarian Republic of) 2.5952 0.0403 14.7565 0.0000 12.9253 0.3389 0.9456 7.6460 3.8328 9.3920 0.0702 0.2078 0.0281 0.3342 0.0009 2.5643 0.1479 3.4106 0.000 0.0009 1.3734 4.1474 35.2416 25.2 3 0.377466 0.003351 0.359703 0.014411 28645000.0
166 Vietnam 1.4591 0.1640 8.5765 0.0042 16.8740 0.3077 2.6392 5.9029 4.4382 0.6069 0.0126 0.4149 0.8410 0.2032 0.2074 1.0596 0.2880 1.2846 0.815 0.3070 0.2201 11.9508 41.4232 2.1 2 0.001457 0.000036 0.001295 0.000126 96209000.0
167 Yemen 0.0364 0.0446 5.7874 0.0000 27.2077 0.2579 0.5240 5.1344 2.7871 1.8911 0.2033 0.2827 0.0893 0.8645 0.0347 1.0794 0.2199 5.0468 0.000 0.0017 1.0811 3.2135 44.2126 14.1 3 0.006987 0.002032 0.004640 0.000315 29826000.0
168 Zambia 5.7360 0.0829 6.0197 0.0000 21.1938 0.3399 1.6924 1.0183 1.8427 1.7570 0.2149 0.3048 1.8736 0.1756 0.0478 7.9649 0.0618 1.5632 0.000 0.0014 0.6657 3.4649 43.9789 6.5 3 0.099663 0.001996 0.094696 0.002970 18384000.0
169 Zimbabwe 4.0552 0.0755 8.1489 0.0000 22.6240 0.2678 0.5518 2.2000 2.6142 4.4310 0.2012 0.2086 0.4498 0.4261 0.0252 2.9870 0.1494 4.6485 0.000 0.0518 1.7103 2.3213 41.8526 12.3 3 0.076418 0.002079 0.064280 0.010059 14863000.0

170 rows × 31 columns

Selecting features and target

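# Drop the identifier column and the case counts other than Deaths
data = data.drop('Country', axis=1)
data = data.drop(['Confirmed', 'Recovered', 'Active'], axis=1)
# Split Deaths into two equal-sized groups: 0 = lower half, 1 = upper half
pd.qcut(data['Deaths'], q=2, labels=[0, 1]).value_counts()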
1    85
0    85
Name: Deaths, dtype: int64
data['Death'] = pd.qcut(data['Deaths'], q=2, labels=[0, 1])
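With q=2, pd.qcut cuts at the median. Its bins are closed on the right, so a value equal to the median goes to label 0. That makes it equivalent to comparing against the median directly. The following toy check illustrates this. The values are hypothetical.

```python
import pandas as pd

deaths = pd.Series([0.1, 0.5, 0.2, 0.9, 0.3, 0.7])
by_qcut = pd.qcut(deaths, q=2, labels=[0, 1]).astype(int)
# Equivalent median split: strictly above the median -> 1
by_median = (deaths > deaths.median()).astype(int)
print(by_qcut.tolist())  # [0, 1, 0, 1, 0, 1]
```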

Splitting and scaling

y = data['Death']
X = data.drop('Death', axis=1)
scaler = sp.StandardScaler()
X = scaler.fit_transform(X)
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8)
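X still contains the continuous Deaths column from which the Death label was derived. The X.shape of (170, 27) in the next section counts 23 food columns plus Obesity, Undernourished, Deaths and Population. Because the label can be read almost directly from one feature, the test scores are likely optimistic. The scaler is also fit on all rows before the split. The following is a minimal sketch of a split without these leaks, assuming the same column names. The helper make_splits, the toy frame, and the use of stratify and random_state are my additions, not part of the original notebook.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def make_splits(data: pd.DataFrame, random_state: int = 0):
    """Split into train/test without the raw Deaths column, scaling on train only."""
    y = data['Death'].astype(int)
    # Drop both the binary label and the continuous column it was derived from
    X = data.drop(['Death', 'Deaths'], axis=1)
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, train_size=0.8, stratify=y, random_state=random_state)
    scaler = StandardScaler()
    X_train = scaler.fit_transform(X_train)  # statistics come from the training split only
    X_test = scaler.transform(X_test)
    return X_train, X_test, y_train, y_test, list(X.columns)

# Hypothetical toy data standing in for the real table
rng = np.random.default_rng(0)
toy = pd.DataFrame({
    'Obesity': rng.normal(20, 5, 50),
    'Population': rng.normal(1e7, 1e6, 50),
    'Deaths': rng.random(50),
})
toy['Death'] = pd.qcut(toy['Deaths'], q=2, labels=[0, 1])
X_train, X_test, y_train, y_test, feature_names = make_splits(toy)
print(feature_names, X_train.shape, X_test.shape)
```

With the real data, this leaves 26 features, so the first Dense layer would take input_shape=(X_train.shape[1],) instead of the hard-coded 27.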

Training

X.shape
(170, 27)
model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, activation='relu', input_shape=(27,)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1, activation='sigmoid'),
])

model.summary()

model.compile(
    optimizer='adam',
    loss='binary_crossentropy',
    metrics=[
        'accuracy',
        tf.keras.metrics.AUC(name='auc'),
    ],
)

batch_size = 64
epochs = 100

history = model.fit(
    X_train,
    y_train,
    validation_split=0.2,
    batch_size=batch_size,
    epochs=epochs,
    verbose=0,
)
Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_3 (Dense)              (None, 64)                1792      
_________________________________________________________________
dense_4 (Dense)              (None, 64)                4160      
_________________________________________________________________
dense_5 (Dense)              (None, 1)                 65        
=================================================================
Total params: 6,017
Trainable params: 6,017
Non-trainable params: 0
_________________________________________________________________
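The parameter counts in the summary follow from the Dense layer formula: inputs × units weights plus one bias per unit. The following short check reproduces the numbers for the 27 → 64 → 64 → 1 network. The helper dense_params is only for illustration.

```python
def dense_params(n_in: int, n_out: int) -> int:
    """Parameters of a fully connected layer: a weight matrix plus one bias per unit."""
    return n_in * n_out + n_out

layers = [(27, 64), (64, 64), (64, 1)]
counts = [dense_params(i, o) for i, o in layers]
print(counts, sum(counts))  # [1792, 4160, 65] 6017
```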

Results

plt.figure(figsize=(14, 10))

epochs_range = range(1, epochs+1)
train_loss = history.history['loss']
val_loss = history.history['val_loss']

plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')

plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.show()

[Figure: Training and Validation Loss, loss vs. epoch]

np.argmin(val_loss)
28
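np.argmin is 0-based, so index 28 corresponds to epoch 29 on the plot's 1-based axis. Training still ran for all 100 epochs. Keras can stop earlier with tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=..., restore_best_weights=True), passed to model.fit through callbacks=[...]. The pure-Python sketch below mimics its patience rule in simplified form (no min_delta) to show which epoch would be kept and when training would stop. The function name and the loss values are hypothetical.

```python
import numpy as np

def early_stop_epoch(val_loss, patience=10):
    """Return (best_epoch, stop_epoch), both 1-based, under a simple patience rule."""
    best_loss = float('inf')
    best_epoch = 0
    for epoch, loss in enumerate(val_loss, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            # No improvement for `patience` epochs in a row: stop here
            return best_epoch, epoch
    return best_epoch, len(val_loss)

toy_val = [1.0, 0.8, 0.6, 0.65, 0.7, 0.9, 0.95]
print(early_stop_epoch(toy_val, patience=3))  # (3, 6)
print(int(np.argmin(toy_val)) + 1)            # 3, the same best epoch
```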
model.evaluate(X_test, y_test)
2/2 [==============================] - 0s 5ms/step - loss: 0.3192 - accuracy: 0.9118 - auc: 0.9386
[0.3191758096218109, 0.9117646813392639, 0.9385964870452881]
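model.evaluate returns the loss followed by the metrics in the order given to compile: [loss, accuracy, auc]. To cross-check the two metrics outside Keras, you can compute them from predicted probabilities (for example, from model.predict(X_test)) with scikit-learn. Keras's AUC metric uses 200 thresholds by default, so it can differ slightly from scikit-learn's exact ROC AUC. The following is a minimal sketch. The helper summarize and the toy probabilities are hypothetical.

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

def summarize(y_true, y_prob, threshold=0.5):
    """Accuracy at a fixed threshold plus threshold-free ROC AUC."""
    y_prob = np.asarray(y_prob).ravel()          # predict() returns shape (n, 1)
    y_pred = (y_prob > threshold).astype(int)
    return accuracy_score(y_true, y_pred), roc_auc_score(y_true, y_prob)

acc, auc = summarize([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8])
print(acc, auc)  # 0.75 0.75
```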
