LoginSignup
2
3

More than 3 years have passed since last update.

Chainer と tflite でグラフ内のデータをダンプするメモ

Last updated at Posted at 2019-07-15

背景

Chainer で学習し TensorFlow-lite で推論するための準備メモ
https://qiita.com/syoyo/items/e84d66b2a938843c1594

により, Chainer で学習したモデルを tflite に変換して推論したが, 結果が合わない.
単体の function/op 単位では chainer2tflite でテストして一致しているが, 複数の function/op を評価するところでなにかおかしいようである.

推論時の, function/op の個々の結果(output)をダンプして比較できるようにしたい.

Chainer

Chainer では, FunctionHook 機能により, forward などの実行時にフック関数(コールバック関数)を使うことができます.

with 構文を使ってフックを実現(with で呼ばれたときに Chainer のグローバルコンテキスト(?)に Hook を登録する)するため, モデルの定義をいじる必要がないのが特徴です.

forward_postprocess をフック関数として使うのがよいですが, 現状ではフック関数に input しか渡ってこないため, forward を呼び出し output を取得する必要があります(内部 state を変えるようなものがあると, forward を二回呼び出すことになるのでうまくいかないかも)

output を forward_postprocess に渡す PR がありますが, なかなかマージされていません.

Chainer-TRT

Chainer-TRT では, 上記 FunctionHook を使い, output を retain しておき, 推論が終わったあとにグラフをトラバースして中間の結果をダンプしています.

とりあえずテンソル(weight)データだけダンプできればよければ, chainer-trt を使うのが手っ取り早そうです. graph の情報は JSON で, weight の情報は簡単なフォーマットでバイナリ形式で出力されます.

ModelRetriever のコンストラクタには verbose=True を指定しましょう. 途中のテンソル(各 function/link の output)を出力してくれます. また, 生成されるモデルの JSON も indent が有効になり読みやすくなります.

tflite

グラフ定義で, 中間計算結果のバッファ(テンソル)を使い回していなければ(たとえばメモリ消費を抑えるために使いまわしているとか), 推論したあとにテンソル(バッファ)のデータをダンプすればよいです.

Python

interpreter の get_tensor_details() でモデル内のテンソル情報を取得できます.

get_tensor(tensor_id) で, テンソルの配列(numpy.ndarray)を取得できます.

invoke したあとに, これら関数を呼んでダンプしましょう.

C++

C++ の場合も, Python のように tensor のリストをダンプすればよいでしょう.

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3