LoginSignup
5
0

Nx.Serving の前処理、後処理でノード間通信のエラーを防ぐ

Last updated at Posted at 2023-06-13

はじめに

Nx.Serving を使うと、行列演算の一連処理を簡単に他ノードから実行できます

ただし、ノード間通信をする場合、データの前処理、後処理を適切にしないとエラーが発生するケースがあります

Nx.Serving には前処理、後処理の仕組みが備えられているので、それらを使ってエラーを回避します

今回も Livebook 上で実装してみます

実装したノートブックはこちら

サーバー側

新しいノートブックを開き、行列演算を実行する側のコードを実行します

サーバー側セットアップ

NxEXLAKino をインストールします

行列演算を高速化するため、 Nx のバックエンドを EXLA に指定します

Mix.install(
  [
    {:nx, "~> 0.5"},
    {:exla, "~> 0.5"},
    {:kino, "~> 0.9"}
  ],
  config: [
    nx: [
      default_backend: EXLA.Backend
    ]
  ]
)

前処理、後処理をしない場合

今回は単に2倍する行列演算 &Nx.multiply(&1, 2) を提供します

また、 Kino.start_child で他ノードからもアクセスできるように子プロセスを開始します

default_serving =
  fn opts -> Nx.Defn.jit(&Nx.multiply(&1, 2), opts) end
  |> Nx.Serving.new(compiler: EXLA)

Kino.start_child({Nx.Serving, name: DefaultServing, serving: default_serving})

この Nx.Serving に対して入力を与える場合、データを Nx.Batch にする必要があります

この変換はいちいち手間に感じます

batch = Nx.Batch.stack([Nx.tensor([1])])

Nx.Serving.run(default_serving, batch)

実行結果は以下のようになります

#Nx.Tensor<
  s64[1][1]
  EXLA.Backend<host:0, 0.511354218.2511208458.206349>
  [
    [2]
  ]
>

つまり、 Nx.Batch.stack([<バッチ1>])Nx.tensor([<バッチ1の演算結果>]) になっています

少し直観的ではないですね

子プロセスを呼び出す場合も同様です

Nx.Serving.batched_run(DefaultServing, batch)

実行結果は同じようになります

#Nx.Tensor<
  s64[1][1]
  EXLA.Backend<host:0, 0.511354218.2511208458.206349>
  [
    [2]
  ]
>

前処理、後処理をする場合

前処理、後処理を設定します

IO.inspect で「どの関数なのか」と「どのノードで実行したのか」を表示します

pre_post_serving =
  fn opts -> Nx.Defn.jit(&Nx.multiply(&1, 2), opts) end
  |> Nx.Serving.new(compiler: EXLA)
  |> Nx.Serving.client_preprocessing(fn input ->
    # 前処理
    IO.inspect("client_preprocessing")
    IO.inspect(Node.self())
    # テンソルで与えられた入力をバッチに変換する
    {Nx.Batch.stack([input]), :client_info}
  end)
  |> Nx.Serving.client_postprocessing(fn output, _metadata, _multi? ->
    # 後処理
    IO.inspect("client_postprocessing")
    IO.inspect(Node.self())
    # 出力テンソルの次元を減らす
    Nx.squeeze(output, axes: [0])
  end)
  |> Nx.Serving.distributed_postprocessing(fn output ->
    # 分散している場合の後処理
    IO.inspect("distributed_postprocessing")
    IO.inspect(Node.self())
    # バイナリバックエンドに変換する
    Nx.backend_transfer(output, Nx.BinaryBackend)
  end)

Kino.start_child({Nx.Serving, name: PrePostServing, serving: pre_post_serving})

それぞれ、以下の関数で処理を追加できます

現在のノードを確認します

Node.self()

結果は以下のようになりました

:"dog6ujac-livebook_server@eb526776ce9e"

直接実行

前処理内でバッチ化するため、テンソルを直接入力にすることができます

Nx.Serving.run(pre_post_serving, Nx.tensor([1]))

実行結果は以下のようになります

前処理、後処理に仕込んだ IO.inspect が動いています

(今回はローカル実行なので、 distributed_postprocessing は動いていません)

また、次元が減って入力のテンソルと同じ形になっています

"client_preprocessing"
:"dog6ujac-livebook_server@eb526776ce9e"
"client_postprocessing"
:"dog6ujac-livebook_server@eb526776ce9e"

#Nx.Tensor<
  s64[1]
  EXLA.Backend<host:0, 0.511354218.2511208458.206383>
  [2]
>

ローカル子プロセス実行

ローカル(同じノード内)で子プロセスを呼び出します

Nx.Serving.batched_run(PrePostServing, Nx.tensor([1]))

ローカル内にあるため、結果は直接実行の場合と同じになります

"client_preprocessing"
:"dog6ujac-livebook_server@eb526776ce9e"
"client_postprocessing"
:"dog6ujac-livebook_server@eb526776ce9e"

#Nx.Tensor<
  s64[1]
  EXLA.Backend<host:0, 0.511354218.2511208458.206394>
  [2]
>

擬似分散実行

Nx.Serving.batched_run の第1引数を {:distributed, <子プロセス名>} のようにすると、別ノードから実行したときと同じ挙動をします(分散実行を明示しています)

distributed_preprocessing = fn input ->
  # 分散している場合の前処理
  IO.inspect("distributed_preprocessing")
  IO.inspect(Node.self())
  # バイナリバックエンドに変換する
  Nx.backend_transfer(input, Nx.BinaryBackend)
end

Nx.Serving.batched_run({:distributed, PrePostServing}, Nx.tensor([1]), distributed_preprocessing)

実行すると、以下のようになります

"distributed_preprocessing"
:"dog6ujac-livebook_server@eb526776ce9e"
"client_preprocessing"
:"dog6ujac-livebook_server@eb526776ce9e"
"client_postprocessing"
:"dog6ujac-livebook_server@eb526776ce9e"
"distributed_postprocessing"
:"dog6ujac-livebook_server@eb526776ce9e"

#Nx.Tensor<
  s64[1]
  [2]
>

実行される順番が以下のようになっていることが分かります

  • distributed_preprocessing
  • client_preprocessing
  • client_postprocessing
  • distributed_postprocessing

クッキーの取得

次に別ノードから接続するため、クッキーを表示しておきます

(ノード名は先程取得済です)

Node.get_cookie()

クライアント側

別のノートブックを開き、以下のコードを実行します

クライアント側セットアップ

サーバー側と同じセットアップを実行します

Mix.install(
  [
    {:nx, "~> 0.5"},
    {:exla, "~> 0.5"},
    {:kino, "~> 0.9"}
  ],
  config: [
    nx: [
      default_backend: EXLA.Backend
    ]
  ]
)

サーバーへの接続

ノード名とクッキーのテキスト入力を作ります

node_name_input = Kino.Input.text("SERVER_NODE_NAME")
cookie_input = Kino.Input.text("SERVER_COOKIE")

[node_name_input, cookie_input]
|> Kino.Layout.grid(columns: 2)

表示されたテキスト入力にサーバー側のノード名、クッキーを入力します

スクリーンショット 2023-06-13 11.24.40.png

サーバーに接続します

node_name =
  node_name_input
  |> Kino.Input.read()
  |> String.to_atom()

cookie =
  cookie_input
  |> Kino.Input.read()
  |> String.to_atom()

Node.set_cookie(node_name, cookie)

Node.connect(node_name)

true が返ってくれば接続できています

分散実行

では、実際に別ノードから実行してみましょう

まず、前処理、後処理のない子プロセスを呼んでみます

batch = Nx.Batch.stack([Nx.tensor([1])])

Nx.Serving.batched_run(DefaultServing, batch)

すると、以下のようなエラーが発生します

(セル自体は正常終了しますが、エラーメッセージが表示され、演算処理は実行されません)

#Inspect.Error<
  got RuntimeError with message:

      """
      Unable to get buffer.
      """

  while inspecting:

      %{
        __struct__: Nx.Tensor,
        data: %EXLA.Backend{
          buffer: %EXLA.DeviceBuffer{
            ref: #Reference<17018.511354218.2511208458.206405>,
            client_name: :host,
            device_id: 0,
            shape: %EXLA.Shape{
              ref: #Reference<17018.511354218.2511470593.75873>,
              dims: {1, 1},
              dtype: {:s, 64}
            }
          }
        },
        names: [nil, nil],
        shape: {1, 1},
        type: {:s, 64}
      }

  Stacktrace:

    (exla 0.5.3) lib/exla/device_buffer.ex:55: EXLA.DeviceBuffer.unwrap!/1
    (exla 0.5.3) lib/exla/backend.ex:133: EXLA.Backend.inspect/2
    (nx 0.5.3) lib/nx/tensor.ex:165: Inspect.Nx.Tensor.inspect/2
    (elixir 1.14.2) lib/inspect/algebra.ex:341: Inspect.Algebra.to_doc/2
    (elixir 1.14.2) lib/kernel.ex:2254: Kernel.inspect/2
    (kino 0.9.4) lib/kino/output.ex:468: Kino.Output.inspect/2
    lib/livebook/runtime/evaluator/default_formatter.ex:41: Livebook.Runtime.Evaluator.DefaultFormatter.to_output/1
    lib/livebook/runtime/evaluator.ex:464: Livebook.Runtime.Evaluator.continue_do_evaluate_code/5

>

これは、 Nx のドキュメントにも書いてあることです

batched_run/3 receives an optional distributed_preprocessing callback as third argument for preprocessing the input for distributed requests. When using libraries like EXLA or Torchx, the tensor is often allocated in memory inside a third-party library so it is necessary to either transfer or copy the tensor to the binary backend before sending it to another node. This can be done by passing either Nx.backend_transfer/1 or Nx.backend_copy/1 as third argument:

Nx.Serving.batched_run(MyDistributedServing, input, &Nx.backend_copy/1)

和訳(DeepL)

batched_run/3 は、分散リクエストのための入力の前処理のために、オプションの distributed_preprocessing コールバックを第 3 引数として受け取ります。 EXLA や Torchx のようなライブラリを使用する場合、テンソルはサードパーティライブラリ内のメモリに割り当てられることが多いので、他のノードに送信する前にテンソルをバイナリバックエンドに転送またはコピーすることが必要です。これは、第 3 引数に Nx.backend_transfer/1 または Nx.backend_copy/1 を渡すことで実行できます。

というわけなので、前処理、後処理付の子プロセスを呼び出します

distributed_preprocessing = fn input ->
  # 分散している場合の前処理
  IO.inspect("distributed_preprocessing")
  IO.inspect(Node.self())
  # バイナリバックエンドに変換する
  Nx.backend_transfer(input, Nx.BinaryBackend)
end

Nx.Serving.batched_run(PrePostServing, Nx.tensor([1]), distributed_preprocessing)

実行結果は以下のようになります

"distributed_preprocessing"
:"hscwsmdh-livebook_server@eb526776ce9e"
"client_preprocessing"
:"dog6ujac-livebook_server@eb526776ce9e"
"client_postprocessing"
:"dog6ujac-livebook_server@eb526776ce9e"
"distributed_postprocessing"
:"dog6ujac-livebook_server@eb526776ce9e"

#Nx.Tensor<
  s64[1]
  [2]
>

distributed_preprocessing だけが自分のノードで実行され、後はサーバー側で実行されました

まとめ

EXLA などのバックエンドを使っている場合、分散処理をするには前処理、後処理が必須になります

また、前処理、後処理で入出力の形式を変更することができます

(Bumblebee は内部でこれを利用しています)

5
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
0