Productionizing Machine Learning: From Deployment to Drift Detection - The Databricks Blogの翻訳です。
多くのブログ記事において、機械学習のワークフローはデータの準備から始まり本番環境へのモデルデプロイで終わります。しかし実際には、それは機械学習モデルのライフサイクルの初めの一歩に過ぎないのです。”人生において変化は起こり続けるものだ”という人もいます。デプロイ後しばらくして、モデルドリフトと呼ばれるモデルの精度劣化が発生するため、これは機械学習モデルにおいても真実と言えます。本記事ではモデルドリフトを検知し対策するのかを説明します。
機械学習におけるドリフトの種別
特徴データや目標変数の依存性の変化があった際にモデルドリフトが起こり得ます。我々は、これらの変化を3つのカテゴリに分類します:概念ドリフト、データドリフト、上流データの変化です。
概念ドリフト(concept drift)
目標変数の統計的属性が変化した時、予測しようとする本当の概念もまた変化します。例えば、不正トランザクションにおいては、新たな手口が生まれてくると、不正の定義自体を見直さなくてはなりません。このような変化は概念ドリフトを引き起こします。
データドリフト(data drift)
入力データから選択された特徴量を用いてモデルをトレーンングします。入力データの統計的特性に変化が生じた際、モデルの品質に影響を及ぼします。例えば、季節性によるデータの変化、個人的嗜好の変化、トレンドなどは入力データのドリフトを引き起こします。
上流データの変化(upstream data changes)
モデル品質に影響を与えうるデータパイプライン上流でのオペレーションの変更が生じる場合があります。例えば、特徴量のエンコーディングにおいて華氏から摂氏に変更があったり、特徴量の生成が停止されることでnullや欠損値になるなどです。
モデルドリフトの検知及び対策
モデルが本格稼働した後でもこのような変更が起こるのであれば、あなたが取るべきベストな選択肢は、変更を監視し、変更が起きた場合に対策を取るということです。モニタリングシステムからのフィードバックループを持ち、長きにわたってモデルをリフレッシュすることでモデルが陳腐化を避けることができます。
上で見たように、様々な原因からドリフトが起こるので、漏れがないように原因となりうる事象をモニタリングする必要があります。以下のシナリオに基づきモニタリングを行うことができます:
トレーニングデータ
- スキーマ、入力データの分布
- ラベルの分布
リクエスト、予測
- スキーマ、リクエストの分布
- 予測の分布
- 予測の品質
Databricksによるモデルドリフト対応
Delta Lakeによるデータドリフト検知
データの品質は、モデルドリフトとモデル品質低下に対する最初の防衛戦となります。Delta Lakeのスキーマ適用、データタイプ、期待品質(quality expectation)などの機能によって高品質、高信頼のデータパイプラインを構築することができます。エラーのあるラベルを削除したり、スキーマを修正・進化させることで、入力データパイプラインを更新し、データ品質、適切性の問題を修正することができます。
Databricks MLランタイム、MLflowによるモデルドリフト、概念ドリフトの検知
モデルドリフトを検知する一般的な方法は予測の品質モニタリングを行うことです。理想的な機械学習モデルのトレーニングは、Delta Lakeのようなデータソースからデータを読み込み、特徴量エンジニアリングを実施し、MLflowによるトラッキングを行いながら、Databricks MLランタイム上でモデルのチューニング、選択を行うという手順を踏むのでしょう。
デプロイメントの段階では、予測を行うためにモデルがMLflowから読み込まれます。パフォーマンスモニタリングや下流のシステムで利用できるように、モデルのパフォーマンス指標や予測結果をDelta Lakeのようなストレージに格納することができます。トレーニングデータ、パフォーマンス指標、予測結果を一つの場所にまとめて格納することで、正確なモニタリングを実現できます。
教師ありトレーニングの際には、モデルの品質を評価するためにトレーニングデータから特徴量とラベルを活用します。モデルがデプロイされたら、二種類のデータを記録しモニタリングします:モデルパフォーマンス指標とモデル品質指標です。
- モデルパフォーマンス指標 推論時間、メモリ消費量などのモデルの技術的側面を示す指標です。Databricksにモデルをデプロイすることで、これらの指標を容易に記録し、モニタリングすることができます。
-
モデル品質指標 この指標は実際のラベルに依存します。ラベルが記録されれば、予測したラベルと実際のラベルを比較することで、品質指標を計算でき、モデルの予測品質におけるドリフトを検知できます。
以下に示す、アーキテクチャの例においては、Delta Lakeからのストリームとして、IoTセンサーからの値(特徴量)、実際の製品品質(ラベル)を読み取ります。このデータを用いて、IoTセンサーデータから製品品質を予測するモデルを構築できます。MLflowにデプロイされたモデルはスコアリングパイプラインに読み込まれ、製品品質を予測値(予測ラベル)を取得します。
ドリフトをモニタリングするために、実際の製品品質(ラベル)と予測品質(予測ラベル)を結合し、タイムウィンドウごとに集計を行い、モデル品質の時系列トレンドに要約します。モデル品質をモニタリングするためのサマリーKPIは、ビジネスニーズによって変化し、十分な網羅性を持つために複数のKPIが計算されます。例として以下のコードスニペットを参照ください。
def track_model_quality(real, predicted):
# 実際のラベルと予測ラベルを結合
quality_compare = predicted.join(real, "pid")
# 予測モデルが正確かどうかを示す列を作成
quality_compare = quality_compare.withColumn(
'accurate_prediction',
F.when((F.col('quality')==F.col('predicted_quality')), 1)\
.otherwise(0)
)
# タイムウィンドウごとの正確な予測の割合のトレンドに要約
accurate_prediction_summary = (quality_compare.groupBy(F.window(F.col('process_time'), '1 day').alias('window'), F.col('accurate_prediction'))
.count()
.withColumn('window_day', F.expr('to_date(window.start)'))
.withColumn('total',F.sum(F.col('count')).over(Window.partitionBy('window_day')))
.withColumn('ratio', F.col('count')*100/F.col('total'))
.select('window_day','accurate_prediction', 'count', 'total', 'ratio')
.withColumn('accurate_prediction', F.when(F.col('accurate_prediction')==1, 'Accurate').otherwise('Inaccurate'))
.orderBy('window_day')
)
return accurate_prediction_summary
予測ラベルに対して実際のラベルの到着がどのくらい遅延するのかによって、これは重要な遅延を示すインジケータにもなり得ます。ドリフトの早期警戒を実現するために、このインジケータは予測品質ラベルの分布のような他のインジケータと組み合わせることもできます。誤検知を避けるために、これらのKPIはビジネス文脈に合わせて設計される必要があります。
ビジネスニーズと照らし合わせて許容できる範囲に、予測精度サマリートレンドの制御リミットの中に設定することもできます。このサマリーは標準的な統計的なプロセス管理手法でモニタリングすることができます。トレンドがこの制御リミットの外に出た際には、警告あるいは新たなデータによる新たなモデルを再作成するなどのアクションをとることができます。
次のステップ
Githubリポジトリーにある指示に従って、上の例を再現し、自身のユースケースに当てはめて下さい。文脈をより理解するためには、ウェビナー「Productionizing Machine Learning – From Deployment to Drift Detection」を参照ください。