LoginSignup
3
7

More than 3 years have passed since last update.

COVID-19を予測してみる

Posted at

はじめに

新型コロナの感染者数が増えてきて、東京、大阪、京都、兵庫で緊急事態宣言が発出されました。感染者数の情報は厚生労働省のページからダウンロードすることができます。まずは陽性者数のデータをダウンロードしてグラフ化してみます。

import pandas as pd
import matplotlib.pyplot as plt
import japanize_matplotlib
import seaborn as sns


def get_data():
    url = 'https://www.mhlw.go.jp/content/pcr_positive_daily.csv'
    df = pd.read_csv(url,
                     usecols=[0, 1],
                     names=['date', 'positives'],
                     skiprows=1,
                     parse_dates=['date'],
                     index_col='date'
                     )
    return df


def main():
    df = get_data()
    df.plot()
    plt.title('COVID19 日別感染者数')
    plt.show()


if __name__ == '__main__':
    main()

covid19-graph.png

素人の私が見ても、これは第四波が来てると思ってしまいます。
この先どうなるのでしょう?
今回はProphetとARIMAを使用して、1ヶ月後までの感染者推移数を予測してみます。実はLSTMを使ったモデルでも試してみたのですが、うまく予想できなかったため別の方法で試してみることにしました。

Prophetで予測してみる

ProphetはFacebookが開発した予測モデルです。

使い方に特長がありますが、モデルを作るのは簡単です。
入力データにはPandasのDataFrameで指定します。このとき時系列の列名をds、測定値の列名をyとしておく必要があります。更にcapという列を追加する必要があり、ここには取り得る最大値を指定します。今回は10000としておきます。データの取得とDataFrameの作成は下記のようになります。

    url = 'https://www.mhlw.go.jp/content/pcr_positive_daily.csv'
    df = pd.read_csv(url,
                     usecols=[0, 1],
                     names=['ds', 'y'],
                     skiprows=1,
                     parse_dates=['ds'],
                     )
    df['cap'] = 10000

モデルの作成は下記のように行います。

    model = Prophet()

何も指定しないと線形回帰になるので、今回は下記のようにlogistic回帰を指定します。

    model = Prophet(growth='logistic')

あとはデータを指定してフィッティングさせます。

    model.fit(df)

将来の予想を行うには、将来の日付を指定したDataFrameを作成します。make_future_dataframeというメソッドを使用すると、指定した先までのデータフレームを作ってくれます。ここでもcap列を追加しておく必要があるので注意です。

    df_future = model.make_future_dataframe(periods=30)
    df_future['cap'] = 10000

予測値を求めてみます。

    predicts = model.predict(df_future)

下記のようにすると、結果をグラフで表示できます。
plt.xkcd()とすると、手書き風のグラフを作成できます。(すみません、やってみたかっただけです。)

    plt.xkcd()
    model.plot(predicts)
    plt.tight_layout()
    plt.title('Prediction COVID-19 by Prophet Model')
    plt.show()

covid19-prophet2.png

増加傾向にあるのはわかります。1週間の周期性を持っているのもわかります。でも、実際の値にはフィットしていませんね...

ARIMAで予測してみる

次はARIMAで予測してみます。感染者数は1週間の周期性があるのがわかっているので、今回はSARIMA(Seasonal ARIMA model)を使ってみます。
ARIMAモデルを理解しておらず説明がうまくできないので、ソースコードだけ書いておきます。

import datetime
import pandas as pd
from matplotlib import pylab as plt
import japanize_matplotlib
import statsmodels.api as sm


def get_data():
    url = 'https://www.mhlw.go.jp/content/pcr_positive_daily.csv'
    df = pd.read_csv(url,
                     usecols=[0, 1],
                     names=['date', 'positives'],
                     skiprows=1,
                     parse_dates=['date'],
                     index_col='date'
                     )
    return df


def main():
    plt.xkcd()

    df = get_data()
    diff = df['positives'].diff()
    diff = diff.dropna()

    params = sm.tsa.arma_order_select_ic(diff, ic='aic', trend='nc')
    aic_order = params['aic_min_order']

    '''
    orderはarma_order_select_icで求めた値を指定。
    seasonal_orderは1週間周期なので4番目に7を指定。
    '''
    model = sm.tsa.SARIMAX(
        df,
        order=(aic_order[0], 1, aic_order[1]),
        seasonal_order=(1, 1, 1, 7)
    ).fit()

    '''
    30日分の予想をしてみる
    '''
    predict_period_from = df.index.max()
    predict_period_to = df.index.max() + datetime.timedelta(days=30)

    predict = model.predict(predict_period_from, predict_period_to)
    plt.plot(df, label='real')
    plt.plot(predict, label='predict')
    plt.title('Prediction COVID-19 by SARIMA Model')
    plt.savefig('covid19-arima2.png')
    plt.show()


if __name__ == '__main__':
    main()

covid19-arima2.png

なんか、いい感じで予測できていますね。

終わりに

今回はDeepLearningをあきらめて、別の方法でCOVID19の予測をしてみました。
ARIMAモデルのすごさもわかりました。DeepLearningと並行して勉強していきたいと思います。
ソースはGitHubに置いています。

3
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
7