Using Bayesian Statistics and PyMC3 to Model the Temporal Dynamics of COVID-19 - The Databricks Blogの翻訳です。
この記事で用いられているノートブックを試してみる
本書では、COVID-19の疾病パラメータを予測するためにどのようにPyMC3を用いるのかを掘り下げます。PyMC3はベイジアンモデリングで用いられる人気のある確率論的プログラミングフレームワークです。これを達成するための人気のある二つの方法は、マルコフ連鎖モンテカルロ法と変分推論法です。ここでは、現在利用可能なアメリカにおける感染者数時系列データを用いて、コンパートメントな確率論的モデルを用いたモデリングを行います。疾病パラメータの推定にトライし、最終的にはMCMCサンプリングを用いてR0を推定します。
ここでお見せするのはデモ目的のものであり、実世界のベイジアンモデリングにおいては、ここでお見せする以上に洗練されたツールが必要となります。人口のダイナミクスに関する様々な仮説が定義されますが、大規模かつ均一ではない人口動態においては適切ではないかもしれません。また、ソーシャルディスタンスやワクチン摂取のような人的介入はここでは考慮されません。
この記事では以下をカバーします。
- 伝染病に対するコンパートメントモデル
- どこからデータがやってきてどのように投入するのか
- SIRおよびSIRSモデルの概要
- PyMC3によるODEのベイジアン推定
- Databricksにおける推定ワークフロー
伝染病に対するコンパートメントモデル
コンパートメントモデルの概要と振る舞いに関しては、Juliaのノートブックを参照してください。
コンパートメントモデルは、コンパートメント(区分)における人口の流入、流出があるものとするクローズドな人口における一連の常微分方程式(ODE)です。これらは、均一な人口動態コンパートメントにおける疾病の伝播をモデリングすることを狙いとしています。想像できるように、大規模な人口においては、この仮説は適切ではないかもしれません。また、このモデルには、当該人口において誕生、死亡という生死に関わる統計情報が含まれていないことがあることにも注意する必要があります。以下のリストでは、疾病拡大の様々なコンパートメントとともにコンパートメントモデルの一覧を示していますが、全てを網羅している訳ではありません。
- Susceptible Infected (SI)
- Susceptible Infected Recovered (SIR)
- Susceptible Infected Susceptible (SIS)
- Susceptible Infected Recovered Susceptible (SIRS)
- Susceptible Infected Recovered Dead (SIRD)
- Susceptible Exposed Infected Recovered (SEIR)
- Susceptible Exposed Infected Recovered Susceptible (SEIRS)
- Susceptible Exposed Infected Recovered Dead (SEIRD)
- Maternally-derived Immunity Susceptible Infectious Recovered (MSIR)
- SIDARTHE
上のリストの最後にあるものは、最新かつCOVID-19に特化したものですので、興味がある方は一読いただければと思います。実世界における疾病モデリングでは、感染したコンパートメントに関連する多くの仮説が存在するため、多くのケースで一つ以上の疾病ステージの時系列的進化が含まれます。どのように感染が拡大するのかを理解するために、空間離散化と人口動態を通じた疾病拡大の進化を見てみようと思います。ここでの空間、時系列的進化のモデリングフレームワークの例はGLEAM(図1)となります。
図1
人々が地理的にどのように移動するのかを理解するために、GLEAMのようなツールは国勢調査データと移動パターンを用います。GLEAMは地形をおおよそ25km x 25kmの空間グリッドに分割します。移動には大きく二つのタイプがあります:グローバル(長距離移動)とローカル(短距離移動)です。長距離移動は飛行機の移動が多く含まれ、空港が感染経路の中央ハブと考えられます。船舶による移動も重要な要素であり、港が別のアクセスポイントと考えられます。上に列挙した数学モデルとともに、このツールはパラメーター推定、予測のための数百万のシミュレーションを行うための確率論的フレームワークを提供します。
感染者数が定期的に更新されるJohns Hopkins CSSEのGithubページからデータを取得します。
データはPythonのpandasで読み込めるCSV形式で提供されています。
SIRおよびSIRSモデル
SIRモデル
SIRモデルは以下に示す3つの常微分方程式(ODE)で提供されます。
ここでは、人口N
における未感染者S
、感染者I
、回復者R
であり、以下が成り立ちます。
S + I + R = N
また、疾病から回復した個人には、生涯抗体が与えられるものとします。多くの疾病でこれは当てはまらず、適切なモデルとは言えないかもしれません。
λは感染確率であり、μは疾病からの回復確率です。感染から回復した人の割合はf
ですが、ここでは意図的に1を設定します。最終的には、I(0)がパンデミック開始時の感染者数として既知なものであり、S(0)はN-I(0)として推定できるものと仮定した、一連のODEの初期値問題(IVP)となります。ここでは人口全体が未感染との仮説を置きます。ここでのゴールは以下の通りとなります:
- λとμを推定するためにベイジアン推論を用いる
- 任意の
t
におけるI(t)を推定するために上のパラメータを使用する - R0を計算する
既に指摘したように、λは疾病の感染係数です。これは、感染者間の時間当たり接触数に依存します。これは人口における感染者数に依存します。
λ = 接触率 x 伝播確率
ある時点t
における感染力あるいはリスクはλ Ιt/Νで定義されます。また、μは単位時間における回復の割合となります。このことから、μ-1は平均回復時間となります。「基本再生産数」R0は単一の一次ケースから生じる二次ケースの平均値です(例:https://www.ncbi.nlm.nih.gov/pmc/articles/PMC6002118/)。また、R0はλとμの比率によっても定義できます。
R0 = λ/μ (S0は1に近いと仮定)
R0>1の場合、感染は拡大しておりパンデミックと言えます。最近のワクチン摂取の努力によって、これはこれまで以上に理解できるものとなっています。人口に対するワクチン摂取の割合p
が(1-p)R0<1となれば、感染拡大が抑制されたことになります。
SIRSモデル
以下に示すSIRSモデルにおいては、感染者が回復した後に生涯抗体が生まれることを仮定しません。このために、回復者コンパートメントから未感染者コンパートメントに人が移動します。このため、獲得された抗体が短期的なものであるCOVID-19においては、優れたベースラインモデルかもしれません。ここでの追加パラメーターは、抗体が失われ、回復プールから未感染者プールに移動した比率を示すγのみです。
ここでは、SIRモデルのみが実装され、SIRSモデルなどの他のモデルに関しては、今後の取り組みとなります。
PyMC3を用いた疾病パラメータの推定
ソリューションの時間を進行させるPyMC3に引き渡せるように、一次時間差分スキーム、二次時間差分スキームを用いてSIRモデルを離散化します。モンテカルロサンプリング法を用いてパラメータλ、μをフィッティングします。
一次スキーム
二次スキーム
PyMC3におけるDifferentialEquationメソッド
手動で任意の高次離散スキームを引き渡すことはできますが、これはすぐに面倒で、エラーが混入しやすいものになり、計算的にも非効率なのは言うまでもありません。幸運にも、PyMC3はこれを行うためのODEモジュールを提供しています。ODEモジュールは入力として、一連のODEの値をベクトルとして返却する関数、ソリューションが望まれる時間ステップ、解決してほしい変数の数、数式の数に対応するステージ数を受け取ります。このメソッドのデメリットの一つとして、遅くなる場合があるということがあります。推奨されるベストプラクティスは、PyMC3のsunode
モジュールを使うというものです。例えば、100サンプル、100のチューニングサンプル、20の時間ポイントにおいて、DifferentialEquationsでは5.4分かかるのに対して、sunodeでは16秒となります。
self.sir_model_non_normalized = DifferentialEquation(
func = self.SIR_non_normalized,
times = self.time_range1:],
n_states = 2,
n_theta = 2,
t0 = 0)
def SIR_non_normalized(self, y, t, p):
ds = -p[0] * y[0] * y[1] / self.covid_data.N,
di = p[0] * y[0] * y[1] / self.covid_data.N - p[1] * y[1]
return[ds, di]
sunodeモジュールを使う文法は以下との通りとなります。いくつかの文法上の違いはありますが、一般的な構造はDifferentialEquationsと同じです。
import sunode
import sunode.wrappers.as_theano
def SIR_sunode(t, y, p):
return {
'S': -p.lam * y.S * y.I,
'I': p.lam * y.S * y.I - p.mu * y.I}
...
...
sir_curves, _, problem, solver, _, _ = sunode.wrappers.as_theano.solve_ivp(
y0={ # Initial conditions of the ODE
'S': (S_init, ()),
'I': (I_init, ()),
},
params={
# Parameters of the ODE, specify shape
'lam': (lam, ()),
'mu': (mu, ()),
'_dummy': (np.array(1.), ()) # currently, sunode throws an error
}, # without this
# RHS of the ODE
rhs=SIR_sunode,
# Time points of th solution
tvals=times,
t0=times[0],
)
SIRモデルに対する推定プロセス
探索するパラメータの推定を行うために、疾病パラメータに対する適切な事前確率を選択するところから開始します。これらのパラメータの振る舞いに対する理解に基づき、対数正規分布が適切であると言えます。理想的には、この対数正規分布の平均パラメータが必要なパラメータの隣に存在してほしいです。よい収束とソリューションのために、データの尤度は適切であること(ドメイン専門知識が重要!)が重要です。以下の尤度をの一つを選択することが一般的です。
- 正規分布
- 対数正規分布
- Student t分布
ODEのソルバーから未感染者数(S(t))、感染者数(I(t))を取得し、以下の通りλとμのサンプルを取得します。
with pm.Model() as model4:
sigma = pm.HalfCauchy('sigma', self.likelihood['sigma'], shape=1)
lam = pm.Lognormal('lambda', self.prior['lam'], self.prior['lambda_std']) # 1.5, 1.5
mu = pm.Lognormal('mu', self.prior['mu'], self.prior['mu_std']) # 1.5, 1.5
res, _, problem, solver, _, _ = sunode.wrappers.as_theano.solve_ivp(
y0={
'S': (self.S_init, ()), 'I': (self.I_init, ()),},
params={
'lam': (lam, ()), 'mu': (mu, ()), '_dummy': (np.array(1.), ())},
rhs=self.SIR_sunode,
tvals=self.time_range,
t0=self.time_range[0]
)
if(likelihood['distribution'] == 'lognormal'):
I = pm.Lognormal('I', mu=res['I'], sigma=sigma, observed=self.cases_obs_scaled)
elif(likelihood['distribution'] == 'normal'):
I = pm.Normal('I', mu=res['I'], sigma=sigma, observed=self.cases_obs_scaled)
elif(likelihood['distribution'] == 'students-t'):
I = pm.StudentT( "I", nu=likelihood['nu'], # likelihood distribution of the data
mu=res['I'], # likelihood distribution mean, these are the predictions from SIR
sigma=sigma,
observed=self.cases_obs_scaled
)
R0 = pm.Deterministic('R0',lam/mu)
trace = pm.sample(self.n_samples, tune=self.n_tune, chains=4, cores=4)
data = az.from_pymc3(trace=trace)
DatabricksにおけるPyMC3を用いた推定ワークフロー
このようなベイジアン推論による疾病パラメータを推定するモデルの構築は、可能な限り自動化したいと考えるインタラクティブなプロセスとなります。様々なパラメータでモデルオブジェクトのインスタンスを取得し、自動化された処理を行うことが良いアイデアかもしれません。幸運にも、Databrikcsノートブックを用いることで、処理の自動実行は容易に実現できます。ノートブックのそれぞれのセルには必要なパラメータの組み合わせ(以下を参照)が記載されており、処理が実行されるとユーザーの介入なしにプロットが行われます。それぞれの処理実行におけるトレース情報、R̂のような推論のメトリクス、他のメタデータを保存するのも良いアイデアです。NetCDFのようなファイルフォーマットを用いることができますが、Pythonのビルトインデータベースモジュールshelve
を用いる方がシンプルです。
covid_obj = COVID_data('US', Population=328.2e6)
covid_obj.get_dates(data_begin='10/1/20', data_end='10/28/20')
sir_model = SIR_model_sunode(covid_obj)
likelihood = {'distribution': 'normal',
'sigma': 2}
prior = {'lam': 1.5,
'mu': 1.5,
'lambda_std': 1.5,
'mu_std': 1.5 }
sir_model.run_SIR_model(n_samples=500, n_tune=500, likelihood=likelihood, prior=prior)
結果サンプル
これらの結果はデモ目的のものであり、このシミュレーションから優位な結果を得るためには綿密な実験が必要となります。アメリカにおいて1月から10月までの感染者数のカウントは以下の通りとなります(図2)。
図2
図3では、λ、μ、R0の事後確率分布が表示された推論結果を示しています。ベイジアン推論を実行することのメリットの一つは、分布が定量化の不確実性を示す最高密度信用区間(HDI)とともに推定値の平均値を示すということです。サンプリングが適切に行われたことを確認するためにトレースをチェック(少なくとも一回は!)するのもグッドアイデアです。
図3
ノートおよびガイドライン
モデリング、推定に対する一般的なガイドラインです:
- 少なくとも5000サンプル、チューニングのためには1000サンプルを使用してください
- 上の結果を得るために、以下のパラメータを使いました:
- 平均値: λ = 1.5、μ = 1.5
- 標準偏差: 両方のパラメータに2.0
- 最低でも3チェーンからサンプリング
- target_acceptを > 0.85に設定
- 可能であればcores=nを設定して並列化してください。
n
は使用するコア数となります。 - 収束のトレースを調査して下さい。
- 限定的な時間サンプルは推定の精度に影響を及ぼします。より多くの高品質のデータを持つことが望ましいです。
- データを正規化してください。大きな値は一般的に収束に対してネガティブに働きます。
モデルのデバッグ
- PyMC3のバックエンドはtheanoであるため、変数値を調査するのにPythonのprint文を使うことはできません。**theano.printing.Print(DESCRIPTIVE_STRING)(VAR)**を使用してください。
-
testval
を指定して確率変数を初期化してください。これは厄介なBad Energy
エラーをチェックするのに役立ちます。これは多くの場合、尤度や事前確率の選択を間違った場合に起こります。検証するために**Model.check_test_point()**を使ってください。 - クイックなデバックのために**step = pm.Metropolis()**を使用してください。これによって、荒い事後確率ではありますが結果を高速に出力します。
- サンプリングが遅い場合には、事前確率、尤度の分布をチェックしてください。
まとめ
この記事では、疾病パラメータを取得するためにPyMC3の使い方の基礎をカバーしました。今後の記事では、Databricks環境、そして、実験のトラッッキングのためのMLflow、ハイパーパラメータ最適化のためのHyperOptのようなワークフローツールのインテグレーションを掘り下げます。
この記事で用いられているノートブックを試してみる