LoginSignup
138
138

More than 5 years have passed since last update.

あなたの機械学習システム構築を手助けする、TensorFlow Extended

Last updated at Posted at 2018-12-14

今日では、機械学習が研究者だけでなく個人レベルで利用できるような時代になってきました。これは、計算機の性能向上や機械学習フレームワークなど開発環境の充実、大量データが手に入りやすくなってきたことなどが要因として挙げられます。

一方、機械学習を用いたシステム(以後本記事では機械学習システムと呼びます)の構築にはハードルがあります。データ傾向の変化など、これまでのシステムにない考慮すべき点が多く存在するからです。2015年の論文においては機械学習モデル作成は一部分でしかなく、運用においてはその他の要素が大きく影響すると述べられていますが、現在でも状況は大きく変わっていないように感じます。

machine_learning_platform.png

出展:https://dl.acm.org/citation.cfm?id=3098021

本記事ではGoogleが提供する機械学習システムの開発プラットフォームであるTensorFlow Extended(TFX)を紹介します。これは上記の図の「Forcus this paper」にあたる機械学習システム構築時に必要となる処理、機能を提供するプラットフォームであり、以下に提示したライブラリで構成されます。

  • TensorFlow Data Validation
  • TensorFlow Transform
  • TensorFlow Model Analysis
  • TensorFlow Serving

本記事ではTFXのチュートリアルのコードやJupyter notebook で実行した結果を交えながら説明を行っていきます。なお、Servingについては紹介記事が多くあるので今回は触れません。

TensorFlow Data Validation

TensorFlow Data Validation は与えられたデータについて、統計量の可視化やデータの検証を行うためのライブラリです。機械学習において与える訓練データの傾向を把握することは大切です。カラムのデータが数値なのか文字列なのかや必ずデータが存在するのか、数値の場合は正規化を行うべきなのかなど確認・考慮すべき点が多くあるためです。このライブラリでは与えられたデータを読み込んで特徴を可視化を実現します。

機能としては大きく分けて3つあり、データの傾向を掴む統計量の確認、与えられるデータの構造であるスキーマ推定、スキーマを用いた異常値の検出を行う事ができる。

統計量の確認

一般的な統計量は以下のようなコードで求めることができます。実行結果は図のようになります。

# Compute stats over training data.
train_stats = tfdv.generate_statistics_from_csv(data_location=os.path.join(TRAIN_DATA_DIR, 'data.csv'))

# Visualize training data stats.
tfdv.visualize_statistics(train_stats)

visualize.png

カラムごとに取得できる統計量は以下です。

数値データ

  • 出現数
  • 欠損率
  • 平均
  • 標準偏差
  • 値ゼロ率
  • 最小値
  • 中央値
  • 最大値

カテゴリデータ

  • 出現数
  • 欠損率
  • ユニーク数
  • 最多出現単語
  • 最多出現単語の出現回数
  • 単語平均長

スキーマ推定

次に、獲得した統計量を利用してこのデータセットのスキーマを推定ができます。訓練データを用いて正しいと思われるデータ構造を推定することでデータの特徴を把握するとともに、後述するエラーを検出に利用します。ただし、推定したスキーマ定義が正しいのか、人間の目でちゃんと確認する事が強く推奨されています。(本当にrequired なカラムであるか、optional なカラムであるかなど)

# Infer a schema from the training data stats.
schema = tfdv.infer_schema(statistics=train_stats, infer_feature_shape=False)
tfdv.display_schema(schema=schema)

上記コードの実行結果は以下のようになります。

scheme.png

推論されたスキーマは以下形式で出力されます。

  • データ数
  • データタイプ
  • データの必須・オプションの区分
  • Valency
  • ドメイン名

カテゴリデータの場合はドメイン名ごとにすべてのカテゴリが表示されます。

domain.png

評価データの異常値確認

これまで訓練データを対象としてきましたが、評価データについても分析が行えます。評価データに対しても統計量を計測し、訓練データの傾向と違いがないか比較することができます。

compare.png

また、前節で作成したスキーマを用いて評価データ中の異常値(これまで見られなかった値など)を確認することができます。以下実行結果を例にすると特徴名「payment_type」において訓練データでは見られなかった「Prcard」が出現していることが確認できます。

# Check eval data for errors by validating the eval data stats using the previously inferred schema.
anomalies = tfdv.validate_statistics(statistics=eval_stats, schema=schema)
tfdv.display_anomalies(anomalies)

anomalies.png

もし予見できる(異常値ではない)データであるならば、対象データを指定したり、min_domain_massを設定してスキーマをアップロードすることによって正常データに含ませることができます。
(ソースによるとmin_domain_massは異常値を許容するしきい値であり、は1(デフォルト)ならばドメイン内のデータすべてが訓練データに含まれていないといけない、0.9なら90%はドメインに含まれていないといけないと判定するようです。)

逆に予期していなかったエラーデータであれば、学習すべきデータが含まれていなかったり異常であるということになるので、訓練データや評価データの見直しを図る必要があるでしょう。

異常値として検出されるエラータイプは以下から確認できます。
https://github.com/tensorflow/metadata/tree/master/tensorflow_metadata/proto/v0/anomalies.proto

補足:環境に応じたスキーマ設定

例えば推定したい正解ラベルは訓練データのみに存在するように、訓練データと検証や本番データのカラムは常に一致するとは限りません。このままでは異常値として検出されてしまうため環境ごとにスキーマを用意し、チェックしなくてよいカラムを除外することで希望とするカラムだけ異常値検出することができます。

このように TensorFlow Data Validation では与えるデータの特徴についてフォーカスし、学習や推論を実行するまでのプロセスをサポートしてくれます。

TensorFlow Transform

このライブラリは前処理を行うためのライブラリです。前処理とはモデルにデータを与える前に何らかの処理を施すことを指します。前処理を行う効果や必要性はこちらの記事が参考になります。

前処理一例としては単語をインデックスへの変換するなどがありますが、これは学習時も推論時も同一の処理を行う必要があります。開発を行っていると、前処理は時間がかかるためとモデルの学習と別々にしてしまうことがあり(私だけ?)、いざ本番での実行を見越した際に前処理とモデルをつなげるコードが新たに必要となってしまうなんてことがあります。そこでTensorFlow Transform は代表的な前処理方法を提供することでモデルと前処理を近づけて、End-to-End(生のデータから推論結果を得る)処理の実現をしやすくしてます。

# Transform training data
preprocess.transform_data(input_handle=os.path.join(TRAIN_DATA_DIR, 'data.csv'),
                          outfile_prefix=TFT_TRAIN_FILE_PREFIX, 
                          working_dir=get_tft_train_output_dir(0),
                          schema_file=get_schema_file(),
                          pipeline_args=['--runner=DirectRunner'])
print('Done')

実行できる前処理の例としては以下などあります。

  • 平均値を標準偏差を用いた正規化
  • テキストデータに対してボキャブラリを作成し、インデックスへの変換
  • データ分布に基づき、浮動小数点の値を整数に変換

公式サイトのAPIを確認すると他にも様々あるので見てみてください。
どのカラムにどのような前処理を行うのかは自身で明示的に実装する必要があります。上記の例ではpreparocess内で様々な前処理が行われています。

TensorFlow Transform によって学習時と推論時の前処理適用が統一できるようになり、処理実行の負担軽減が期待できそうです。

TensorFlow Model Analysis

このライブラリでは作成したモデルの評価を行うことができます。従来でも訓練データや評価データに対して作成したモデルの正解率などを計測することは、もちろん行われてきました。一方で、推論がうまくいかない特定のデータに対して分析するなど、個別に集計して確認することはなかなか手間がかってしまいます。

TensorFlow Model Analysis はモデル評価を様々な切り口(以降では観点と呼称します)で手軽にできるようにします。このことにより、モデル性能の改善ヒントを得られるかもしれません。

以降では3つの分析方法を紹介します。

Visualization: Slicing Metrics

カラムの種類や値などを観点として分析する方法です。例えばあるカラムの値ごとに正解率を算出し、グラフ化することができます。観点の指定の例としては以下のようになります。

# An empty slice spec means the overall slice, that is, the whole dataset.
OVERALL_SLICE_SPEC = tfma.SingleSliceSpec()

# Data can be sliced along a feature column
# In this case, data is sliced along feature column trip_start_hour.
FEATURE_COLUMN_SLICE_SPEC = tfma.SingleSliceSpec(columns=['trip_start_hour'])

# Data can be sliced by crossing feature columns
# In this case, slices are computed for trip_start_day x trip_start_month.
FEATURE_COLUMN_CROSS_SPEC = tfma.SingleSliceSpec(columns=['trip_start_day', 'trip_start_month'])

# Metrics can be computed for a particular feature value.
# In this case, metrics is computed for all data where trip_start_hour is 12.
FEATURE_VALUE_SPEC = tfma.SingleSliceSpec(features=[('trip_start_hour', 12)])

# It is also possible to mix column cross and feature value cross.
# In this case, data where trip_start_hour is 12 will be sliced by trip_start_day.
COLUMN_CROSS_VALUE_SPEC = tfma.SingleSliceSpec(columns=['trip_start_day'], features=[('trip_start_hour', 12)])

ALL_SPECS = [
    OVERALL_SLICE_SPEC,
    FEATURE_COLUMN_SLICE_SPEC, 
    FEATURE_COLUMN_CROSS_SPEC, 
    FEATURE_VALUE_SPEC, 
    COLUMN_CROSS_VALUE_SPEC    
]

「FEATURE_COLUMN_SLICE_SPEC」 のように「trip_start_hour」の値毎に集計をおこなったり、「COLUMN_CROSS_VALUE_SPEC」 のように「trip_start_hour」の値毎かつ「trip_start_hour」が12のものといったように計測したい条件を組み合わせることが可能です。

trip_start&trip_start_hour.png

Visualization: Plots

ROC Curve やPrediction-Recall Curve など用意されたプロット方法を用いて特定の観点の分析を行う事ができます。以下は「trip_start_hour」カラムが0のものに対してプロットした結果です。

plot.png

Visualization: Time Series

機械学習は与えるデータの傾向が変わると性能も変化する可能性があります。また、モデルの改良を加えることによって性能が向上することもあれば、低下することもあります。そこで与えるデータの変化や同一データに対して適用させるモデルが変化によって、機械学習の性能がどのように影響を受けたのか記録、確認する事は重要です。Time Series ではこの時系列で性能がどのように変化していったのかを確認することができます。
以下はモデルのパラメータを変更した3つのモデルをプロットしたものです。(見づらいですが10桁の数字がモデル番号です。)ここではaccuracyとaucについてプロットしてますが他にもaverage lossなど様々な指標で確認することができます。

timeline_graph.png

TensorFlow Analyze による3つの分析を用いることによって手間がかかっていた詳細な分析を手軽にできるようになり、より効率的に機械学習の開発が行えるようになるかもしれません。

まとめ

ここまで紹介したようにTensorFlowは機械学習モデルの構築だけでなく、周辺の機能も多く提供しています。年末にはTensorFlow 2.0のリリースもあるとのことなので今後も要チェックです!

機械学習や自然言語処理について、つぶやいてますのでフォローしていただけると嬉しいです。あとブログもやってますのでよろしくお願いします!
@kamujun18
Technical Hedgehog

Reference

https://ai.googleblog.com/2017/02/preprocessing-for-machine-learning-with.html
https://medium.com/tensorflow/introducing-tensorflow-data-validation-data-understanding-validation-and-monitoring-at-scale-d38e3952c2f0
https://medium.com/tensorflow/introducing-tensorflow-model-analysis-scaleable-sliced-and-full-pass-metrics-5cde7baf0b7b
https://www.youtube.com/watch?v=vdG7uKQ2eKk

138
138
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
138
138