13
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

NxとNIFの間で相互通信するには

Last updated at Posted at 2021-09-23

はじめに

Nxは,NumPyやTensorFlowと同じような位置付けにあるElixirのライブラリで,Elixirの創始者のJosé Valimがチームを組んで精力的に開発を行なっています。2020年に初めて発表され,2021年9月現在もまだ開発プレビュー版の状態ですが,まもなくリリースされるとも言われています。

Nxにはバックエンドを自由に定義することができ,EXLAはGoogleのXLAを呼び出すNxのバックエンドです。何も設定しない状態で使われるバックエンドはBinary Backendというものです。独自のバックエンドを定義することもできるので,私が行なっているコード最適化の研究対象としてはもってこいです。

今回,独自のバックエンドを実装する上で,最も基礎となる,NxのBinary BackendのデータをC言語のNIFで記述したプログラム上で計算し,その結果をNxに書き戻す処理を実装しました。仮に三角関数(Nx.sin/1相当)を例題として実装しています。

ベンチマーク結果

M1 Mac mini でベンチマーク実行した結果によると,9倍程度の高速化となりました。これは,mix run samples/sin_benchmarks.exs を実行すると得られます。

% mix run -r samples/sin_benchmarks.exs
make: Nothing to be done for `all'.
Operating System: macOS
CPU Information: Apple M1
Number of Available Cores: 8
Available memory: 8 GB
Elixir 1.12.3
Erlang 24.1.2

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
parallel: 1
inputs: none specified
Estimated total run time: 14 s

Benchmarking NIF 32...
Benchmarking Nx...

Name             ips        average  deviation         median         99th %
NIF 32      798.96 K        1.25 μs  ±1109.42%        0.99 μs        1.99 μs
Nx           89.85 K       11.13 μs    ±57.28%        9.99 μs       35.99 μs

Comparison: 
NIF 32      798.96 K
Nx           89.85 K - 8.89x slower +9.88 μs

ソースコード

NIFをコンパイルする方法については詳細は割愛しますが,mix.exsMakefile を見ていただければと思います。elixir_make を使って make を実行しています。

Cのコードはこちらです。

c_src/libnif.c
#include <stdbool.h>
#include <stdint.h>
#include <math.h>
#include <erl_nif.h>

void sin32(uint64_t size, float *in, float *out)
{
    for(uint64_t i = 0; i < size; i++) {
        out[i] = sin(in[i] * 2 * PI);
    }
}

static ERL_NIF_TERM sin32_nif(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    if(__builtin_expect(argc != 2, false)) {
        return enif_make_badarg(env);
    }
    ErlNifUInt64 vec_size;
    if(__builtin_expect(!enif_get_uint64(env, argv[0], &vec_size), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM binary_term = argv[1];
    ErlNifBinary in_data;
    if(__builtin_expect(!enif_inspect_binary(env, binary_term, &in_data), false)) {
        return enif_make_badarg(env);
    }

    float *in = (float *)in_data.data;
    ErlNifBinary out_data;
    if(__builtin_expect(!enif_alloc_binary(vec_size * sizeof(float), &out_data), false)) {
        return enif_make_badarg(env);
    }
    float *out = (float *)out_data.data;

    sin32(vec_size, in, out);

    return enif_make_binary(env, &out_data);
}

static ErlNifFunc nif_funcs[] = 
{
    {"sin32_nif", 2, sin32_nif}
};

ERL_NIF_INIT(Elixir.NxNif, nif_funcs, NULL, NULL, NULL, NULL)

sin32_nif は,ベクタサイズ(整数)とバイナリデータを受け取って,ベクタに三角関数を適用し,バイナリデータを返します。

float *array = (float *)in_data.data;の部分でバイナリデータをキャストします。たとえば,もし符号なし16ビット整数にキャストしたいのであれば,uint16_t *array = (uint16_t *)in_data.data;のようにします。out_dataも同様にします。

その直後の sin32関数の呼出に相当する部分で,目的とする関数を呼び出します。ここでは仮に三角関数の変換を配列に対して行うsin32関数を呼び出しています。

NIFに関する記述は割愛します。

呼び出し側のElixirのコードはこちらです。

nx_nif.ex
defmodule NxNif do
  require Logger

  @moduledoc """
  Documentation for `NxNif`.
  """

  @on_load :load_nif

  def load_nif do
    nif_file = '#{Application.app_dir(:nx_nif, "priv/libnif")}'

    case :erlang.load_nif(nif_file, 0) do
      :ok -> :ok
      {:error, {:reload, _}} -> :ok
      {:error, reason} -> Logger.error("Failed to load NIF: #{inspect(reason)}")
    end
  end

  def sin32(x) when is_struct(x, Nx.Tensor) do
    if Nx.type(x) == {:f, 32} do
      x
    else
      Nx.as_type(x, {:f, 32})
    end
    |> sin32_sub()
  end

  def sin32(x) when is_number(x) do
    sin32(Nx.tensor([x]))
  end

  defp sin32_sub(t) do
    %{
      t
      | data: %{
          t.data
          | state: sin32_nif(Nx.size(t), t.data.state)
        }
    }
  end

  def sin32_nif(_size, _x), do: raise("NIF sin32_nif/2 not implemented")
end

sin32_nif関数は前述のC関数を呼び出すスタブコードで,sin32関数とsin32_sub関数がラッパーとなっています。sin32関数は,16ビット浮動小数点数のテンソルの形式に変換します。sin32_sub関数は,Nxのデータ構造からサイズとバイナリを取り出してsin32_nifを呼び出し,再びNxのデータ構造に戻しています。

CとElixirの役割分担として,Elixirの側でNxのデータ構造からサイズとバイナリを取り出し,再びNxのデータ構造に戻す役割を担っています。Cの側で担うこともできるのですが,ソースコードがかなり煩雑になるのと,実行速度が若干遅くなるという問題があるので,Elixirの側で担うことにしました。

おわりに

この記事で紹介した方法により,NxとNIFの間で相互通信することができるようになりました。また,これにより数倍程度の高速化を図れることもわかりました。今後は,この方法を活用して,Nxで新規のバックエンドを作ってみたいと思います。

13
4
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?