9
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Nx.Serving で複数の入力を逐次・並列処理する

Last updated at Posted at 2023-06-13

はじめに

Nx.Serving を使うと、行列演算の一連処理を簡単に他ノードから実行できます

複数ノードに接続していれば自動的に分散して処理してくれますが、これを大量データに対して逐次・並列で呼び出してみます

2023/6/14 更新

Flow.map 内で分散処理がエラーになる件について、 Issue が解消されました

今回も Livebook 上で実装します

実装したノートブックはこちら

セットアップ

NxFlowKino をインストールします

2023/6/14 現在、 Flow.map 内で分散処理を実行するためには GitHub から最新版をインストールする必要があります

次のリリースが出るまではこの方法でインストールしましょう

Mix.install(
  [
    {:nx, github: "elixir-nx/nx", branch: "main", sparse: "nx", override: true},
    {:flow, "~> 1.2"},
    {:kino, "~> 0.9"}
  ]
)

サービス定義

与えられたテンソルをそのまま返す Nx.Serving を用意し、子プロセスで起動します

実行順序が分かるように IO.inspect で前処理/後処理と入出力の値を出力します

serving =
  fn opts -> Nx.Defn.jit(&(&1), opts) end
  |> Nx.Serving.new()
  |> Nx.Serving.client_preprocessing(fn input ->
    IO.inspect("client_preprocessing #{Nx.to_number(input[0])}")
    {Nx.Batch.stack([input]), :client_info}
  end)
  |> Nx.Serving.client_postprocessing(fn output, _metadata, _multi? ->
    IO.inspect("client_postprocessing #{Nx.to_number(output[[0, 0]])}")
    Nx.squeeze(output, axes: [0])
  end)

Kino.start_child({Nx.Serving, name: Echo, serving: serving})

ローカル実行

ローカル(同じノード=ノートブック)から呼び出します

ローカル単一処理

まずは単一の値を入力にします

Nx.Serving.batched_run(Echo, Nx.tensor([1]))

実行結果は以下のようになります

"client_preprocessing 1"
"client_postprocessing 1"

#Nx.Tensor<
  s64[1]
  [1]
>

ローカル逐次処理

次に、 Enum.map で逐次処理を実行します

[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Enum.map(fn input ->
  Nx.Serving.batched_run(Echo, input)
end)

実行結果を見ると、逐次処理されていることが分かります

"client_preprocessing 1"
"client_postprocessing 1"
"client_preprocessing 2"
"client_postprocessing 2"
"client_preprocessing 3"
"client_postprocessing 3"

[
  #Nx.Tensor<
    s64[1]
    [1]
  >,
  #Nx.Tensor<
    s64[1]
    [2]
  >,
  #Nx.Tensor<
    s64[1]
    [3]
  >
]

ローカル並列処理

次は Flow.map を使って並列処理を実行します

stages に 3 を指定し、 3 並列で実行してみましょう

[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Flow.from_enumerable(stages: 3, max_demand: 1)
|> Flow.map(fn input ->
  Nx.Serving.batched_run(Echo, input)
end)
|> Enum.to_list()

実行結果で並列処理されていることが確認できます

"client_preprocessing 1"
"client_preprocessing 2"
"client_preprocessing 3"
"client_postprocessing 1"
"client_postprocessing 2"
"client_postprocessing 3"

[
  #Nx.Tensor<
    s64[1]
    [1]
  >,
  #Nx.Tensor<
    s64[1]
    [2]
  >,
  #Nx.Tensor<
    s64[1]
    [3]
  >
]

分散実行

実行時に子プロセスを {:distributed, Echo} のように指定することで、同じノード内でも分散処理と同じように呼び出すことができます

分散単一処理

Nx.Serving.batched_run({:distributed, Echo}, Nx.tensor([1]))

特に問題なく実行できます

"client_preprocessing 1"
"client_postprocessing 1"

#Nx.Tensor<
  s64[1]
  [1]
>

分散逐次処理

[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Enum.map(fn input ->
  Nx.Serving.batched_run({:distributed, Echo}, input)
end)

こちらも問題ありません

"client_preprocessing 1"
"client_postprocessing 1"
"client_preprocessing 2"
"client_postprocessing 2"
"client_preprocessing 3"
"client_postprocessing 3"

[
  #Nx.Tensor<
    s64[1]
    [1]
  >,
  #Nx.Tensor<
    s64[1]
    [2]
  >,
  #Nx.Tensor<
    s64[1]
    [3]
  >
]

並列分散処理

[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Flow.from_enumerable(stages: 3, max_demand: 1)
|> Flow.map(fn input ->
  Nx.Serving.batched_run({:distributed, Echo}, input)
end)
|> Enum.to_list()

これを実行すると、エラーになります

エラーは以下の Issue で解消されました

実行結果は以下のようになります

"client_preprocessing 1"
"client_preprocessing 2"
"client_preprocessing 3"
"client_postprocessing 1"
"client_postprocessing 2"
"client_postprocessing 3"

[
  #Nx.Tensor<
    s64[1]
    [1]
  >,
  #Nx.Tensor<
    s64[1]
    [2]
  >,
  #Nx.Tensor<
    s64[1]
    [3]
  >
]

Task を利用する方法もあります

[Nx.tensor([1]), Nx.tensor([2]), Nx.tensor([3])]
|> Enum.map(fn(tensor) -> Task.async(fn ->
  Nx.Serving.batched_run({:distributed, Echo}, tensor)
end) end)
|> Enum.map(fn(task) -> Task.await(task) end)

実行結果は同じようになります

"client_preprocessing 1"
"client_preprocessing 2"
"client_preprocessing 3"
"client_postprocessing 1"
"client_postprocessing 2"
"client_postprocessing 3"
[
  #Nx.Tensor<
    s64[1]
    [1]
  >,
  #Nx.Tensor<
    s64[1]
    [2]
  >,
  #Nx.Tensor<
    s64[1]
    [3]
  >
]

まとめ

2023年6月13日現在、 Nx.Serving による分散処理を Flow で並列実行することはできませんでした

とりあえず Task を使いましょう

当初 Flow.map での並列分散がエラーになりましたは、 Issue を登録するとすぐに解消してくれました

これで簡単に並列分散行列演算が実行できます

9
3
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
9
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?