3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

AxonでResNet18を我流でやってみた【設計者編】

Last updated at Posted at 2022-11-09

0.Prologue

先日投稿した同タイトルの記事の続編。裏方の AxonInterp(試作品)の簡単な紹介です。

1.はじめにコードありき - 設計指針

AxonInterpは次の要件を満たすように設計しました。

  1. DNNモデルの推論実行とそれに関連する前処理・後処理を一塊のコードとして扱いたい。
  2. 一つのアプリから複数種類のDNNモデルを利用できるようにしたい。
  3. DNNモデルのロードや初期化はアプリの起動時に一度行うだけで良いようにしたい。
  4. どのDNNモデルの推論においても変わらない処理は共通関数にして使い回したい。

つまり、DNNモデルの推論を概ね下の疑似コードの様に書きたいと言うことです。

defmodule DNN推論モジュールその1 do
 # 起動時に一度だけ
  AxonInterp.モデルのロード&初期化 model: "<model file>"

  def apply(inputs) do
    # 前処理なんたらかんたら

    # DNN推論を実行
    outputs =  # 共通関数を用意
      |> AxonInterp.入力データをセット(0, input0)
      |> AxonInterp.入力データをセット(1, input1)
      :
      |> AxonInterp.推論を実行()

    output0 = AxonInterp.推論結果の取り出し(outputs, 0)
    output1 = AxonInterp.推論結果の取り出し(outputs, 1)
    :

    # 後処理なんたらかんたら
  end
end

defmodule DNN推論モジュールその2 do
 :
end

これらの要件を満たすには、仮にオブジェクト指向型言語を用いるのであれば、共通メソッドを持つベース・クラスと、それぞれのDNNモデルの推論用にカスタマイズした派生クラスとして表現できるでしょう。一方、関数型言語である Elixirはオブジェクト指向のクラス継承の機能を持ちませんが、それに似たことを GenServerと use/__using__マクロを用いれば実現できそうです。GenServerで 各DNNモデルに固有の情報(オブジェクト)を持たせ、use/__using__マクロで共通な関数を配布する(継承)と言った設計をにすれば良いように思います。

2.薄~~いラッパーモジュールAxonInterp

前章で掲げた設計指針に沿って実装した AxonInterpは下のようなコードになりました(主要部のみ抜粋)。以下要件3),4)にフォーカスして簡単に仕組みをみていきます。

17行目から始まる __using__マクロのコード・ブロック(至82行)が、バックボーンに当たる仕掛けです。このコード・ブロックは DNN推論モジュールに書かれた useマクロでまるまる取り込まれて、そのモジュールの下で GenServerの初期化関数やハンドラがセットアップされます。

この GenServerは DNN推論モジュールの専用の GenServerなので、DNNモデル・ファイルから Axon.deserialize & Axon.buildが生成した predict_fn関数や paramsなど DNNモデルに固有の情報(状態)を持たせます(24行目)。こうすることで、いつでも predict_fn関数や paramsを利用できるようにしています(初期化関数 init/1: 34~49行目)。

また、GenServerハンドラもDNN推論モジュールの専用なので、DNN推論モジュールの名前(atom)を付してGenServer.call/3を呼ぶと、そのモジュールに紐づけされたハンドラが自動的に呼びだされます。これにより、一種のポリモフィズムな呼び出しが実現できます(96行目、126行目)

以上がバックボーンの仕組みで、要件3),4)の実現方法です:sunglasses:

  1: defmodule AxonInterp do
  2:   @timeout 300000
  3:
  4:   # バックエンドのフレームワーク名
  5:   @framework "Axon"
  6:
  7:   # @frameworkのモデル・ファイルの拡張子一覧
  8:   suffix = %{
  9:    "axon" => [".axon", ".onnx"]
 10:   }
 11:   @model_suffix suffix[String.downcase(@framework)]
 12:  
 13:   # セッション情報 - 推論リクエスト毎の入出力情報を保持するレコード
 14:   defstruct module: nil, inputs: %{}, outputs: %{}
 15:
 16:   # GenServerの初期化、ハンドラの配布
 17:   defmacro __using__(opts) do
 18:     quote generated: true, location: :keep do
 19:       use GenServer
 20:
 21:       # アクティベイト
 22:       def start_link(opts) do
 23:         # useしたモジュールに GenSeverを紐づける -> モジュール毎に個別の状態を持たせる
 24:         GenServer.start_link(__MODULE__, opts, name: __MODULE__)
 25:       end
 26:
 27:       @doc """
 28:       初期化 - モデル・ファイルの読み込みと準備
 29:         :model - モデル・ファイルの格納先PATH
 30:         :url - モデル・ファイルが上記PATHに見つからなかった場合に探しに行くURL
 31:         :inputs - モデルの入力tensorの仕様 ex) [f32: {1,3,224,224}, f32: {80}] 
 32:         :outputs - モデルの入力tensorの仕様 ex) [f32: {1, 1000}] 
 33:       """
 34:       def init(opts) do
 35:         opts = Keyword.merge(unquote(opts), opts)
 36:         nn_model   = AxonInterp.validate_model(Keyword.get(opts, :model), Keyword.get(opts, :url))
 37:         nn_inputs  = Keyword.get(opts, :inputs, [])
 38:         nn_outputs = Keyword.get(opts, :outputs, [])
 39:
 40:         # Axonモデルをビルド
 41:         {model, params} = case Path.extname(nn_model) do
 42:           ".axon" -> File.read!(nn_model) |> Axon.deserialize()
 43:           ".onnx" -> AxonOnnx.import(nn_model)
 44:         end
 45:         {_, predict_fn} = Axon.build(model, [])
 46:
 47:         # Axonの predict_fn、params、入出力のtensor仕様を状態として保持する(以降不変)
 48:         {:ok, %{model: predict_fn, params: params, path: nn_model, itempl: nn_inputs, otempl: nn_outputs}}
 49:       end
 50:
 51:       # セッション情報の生成
 52:       def session() do
 53:         # GenServerハンドラの呼び出し先__MODULE__で初期化
 54:         %AxonInterp{module: __MODULE__}
 55:       end
 56:
 57:       # infoハンドラ
 58:       def handle_call({:info}, _from, state) do
 59:         info = %{
 60:           "model"   => state.path,
 61:           "inputs"  => state.itempl,
 62:           "outputs" => state.otempl,
 63:         }
 64:         {:reply, {:ok, info}, state}
 65:      end
 66:      
 67:      # invokeハンドラ
 68:      def handle_call({:invoke, inputs}, _from, %{model: model, params: params, itempl: template}=state) do
 69:         inputs = Enum.with_index(template)
 70:           |> Enum.map(fn {{dtype, shape}, index} -> Nx.from_binary(inputs[index], dtype) |> Nx.reshape(shape) end)
 71:         # 注)まだ1入力1出力のモデルしか扱えない
 72:         input0 = Enum.at(inputs, 0)
 73:         result = model.(params, input0) |> Nx.to_binary()
 74:         {:reply, {:ok, result}, state}
 75:       end
 76:
 77:       # GenServerの停止ハンドラ
 78:       def terminate(_reason, state) do
 79:         :ok
 80:       end
 81:     end
 82:   end
 83:
 84:   doc """
 85:   バックエンド・フレームワークの名前
 86:   """
 87:   def framework() do
 88:     @framework
 89:   end
 90:
 91:   @doc """
 92:   DNNモデルの情報
 93:     mod - DNNモデルを実装したモジュール名
 94:   """
 95:   def info(mod) do
 96:     case GenServer.call(mod, {:info}, @timeout) do
 97:       {:ok, result} ->  {:ok, Map.put(result, "framework", @framework)}
 98:       any -> any
 99:     end
100:   end
101:
102:   @doc """
103:   DNNモデルを停止する
104:     mod - DNNモデルを実装したモジュール名
105:   """
106:   def stop(mod) do
107:     GenServer.stop(mod)
108:   end
109:
110:   @doc """
111:   DNNモデルの入力をセット
112:     session - セッション情報
113:     index - 入力項のインデックス
114:     bin - 入力データ(tensorをバイナリに変換したもの)
115:   """
116:   def set_input_tensor(%AxonInterp{inputs: inputs}=session, index, bin) when is_binary(bin) do
117:     # 入力データは一旦session情報に保持する
118:     %AxonInterp{session | inputs: Map.put(inputs, index, bin)}
119:   end
120:
121:   @doc """
122:   DNNモデル推論の実行
123:     session - セッション情報
124:
125:   使用例:
126:     output_0 = session()
127:       |> AxonInterp.set_input_tensor(0, input_0)
128:       |> AxonInterp.invoke()
129:       |> AxonInterp.get_output_tensor(0)
124:   """
125:   def invoke(%AxonInterp{module: mod, inputs: inputs, outputs: outputs}=session) do
126:     case GenServer.call(mod, {:invoke, inputs}, @timeout) do
127:       {:ok, result} -> %AxonInterp{session | outputs: Map.put(outputs, 0, result)}
128:       any -> any
129:     end
130:   end
131:
132:   @doc """
133:   DNNモデルの推論結果の取り出し
134:     session - セッション情報
135:     index - 出力項のインデックス
136:   """ 
138:   def get_output_tensor(%AxonInterp{outputs: outputs}, index) do
139:     # session情報に保持されている推論結果を取り出す
140:     outputs[index]
141:   end
142:
143:   ### ここから先はモデル・ファイルのチェックやダウンロード関するコードのため省略するm(_ _)m
144:   :
145: end

バックボーンの仕組みと共に押さえておきたいパーツが構造体 %AxonInterp{}です(14行目)。%AxonInterp{}の役割は、アプリからリクエストされた推論セッションの入力データと推論結果を記憶することです(116~141行目)。ご存じの通り関数型言語では、原則として手続き(関数)は内部状態を持つことができません(無記憶)。そのため、AxonInterpの様に複数の関数が共同して一つの結果を計算する場合は、処理の途中結果(関数からの返値データ、次の関数への引数データ)を何らかの方法で持ち歩く必要があります。%AxonInterp{}はそのための構造体です。

実は、AxonInterpの実体は GenServer(Actor)なので、処理の途中結果を GenServerで記憶するという設計の選択肢もありましたが、将来の並行動作を視野に入れて関数型言語のセオリーに従うことにしました[*1]

[*1]長女 TflInterpは、Tensorflow Liteの設計に倣い、Port(Actor)側に途中結果を持つインターフェイスも備えています。

その他の モデル・ファイルの拡張子によるチェックやダウンロード機能については、面白味の無い補助機能なので紹介を割愛します:wink:

3.Epilogue

AxonInterp他の *Interpシリーズは、GenServerと__using__/useマクロを利用して実装しています。use/__using__マクロは単純な機構にも関わらず、とても応用範囲の広い仕組みだと思います。2~3有名どころライブラリを調べてみると、巧みにuse/__using__マクロが用いられているのを目にします。是非ともマスターしたいテクニックですね:yum:

この記事が何かの参考になれば幸いです。
(おしまい)

3
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?