Scalable Matrix Extension (SME) Advent Calendar 2024

Day 20

SME Diary Part 19: Writing scalar multiplication of a vector with SME


At last, I have written a program that performs scalar multiplication of a vector using the Scalable Matrix Extension (SME), and compared it against the Apple BLAS version.

SME series

Source code

This time I made a breaking change: the multiply function is now used in two steps, first obtaining a function object and then calling it to perform the computation.
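Usage now looks like this two-step call (mirroring the doctests later in this article):

mul = NxSgemm.multiply()   # returns either the SME- or the BLAS-backed function
mul.(Nx.tensor([1.0, 2.0, 3.0], type: :f32), 2.0)
# => #Nx.Tensor< f32[3] [2.0, 4.0, 6.0] >

The SME kernel itself is the following excerpt from nif_src/libnif.c: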

__arm_locally_streaming
__arm_new("za")
void multiply_factor_in_to_out(float factor, float *in, float *out, ErlNifUInt64 vec_size)
{
    // Duplicate scalar across all lanes
    svfloat32_t factor_vec = svdup_f32((float32_t)factor);

    // Loop over the vector in chunks of vector size
    // svcntw() gives the number of elements per register
    for (ErlNifUInt64 i = 0; i < vec_size; i += svcntw()) {
        svbool_t mask = svwhilelt_b32((uint64_t)i, (uint64_t)vec_size);

        // Load the vector chunk into an SVE register
        svfloat32_t vec_chunk = svld1_f32(mask, &in[i]);

        // Perform element-wise multiplication
        svfloat32_t result = svmul_f32_m(mask, vec_chunk, factor_vec);
        
        // Store the result back into memory
        svst1_f32(mask, &out[i], result);   
    }
}
  • svdup_f32 broadcasts the scalar constant across all lanes of a vector register.
  • svwhilelt_b32 builds the predicate mask: on most iterations every lane is active, and on the final iteration the out-of-range lanes are masked off. For example, with vec_size = 100 and 16 lanes per register, the last iteration starts at i = 96 and svwhilelt_b32(96, 100) enables only the first four lanes.
  • svld1_f32 loads a chunk of the vector under the mask.
  • svmul_f32_m performs the element-wise multiplication under the mask.
  • svst1_f32 stores the result back to memory under the mask.
lib/nx_sgemm.ex
defmodule NxSgemm do
  @moduledoc """
  Documentation for `NxSgemm`.
  """
  require Logger

  @on_load :load_nif

  @doc false
  def load_nif do
    nif_file = ~c'#{Application.app_dir(:nx_sgemm, "priv/libnif")}'

    case :erlang.load_nif(nif_file, 0) do
      :ok -> :ok
      {:error, {:reload, _}} -> :ok
      {:error, reason} -> Logger.error("Failed to load NIF: #{inspect(reason)}")
    end
  end

  @doc """
  ok.

  ## Examples

      iex> NxSgemm.ok()
      :ok

  """
  def ok(), do: :erlang.nif_error(:not_loaded)

  @doc """
  Element-wise multiplication of two tensors.

  If a number is given, it is converted to a tensor.

  It will broadcast tensors whenever the dimensions do not match and broadcasting is possible.

  ## Examples

  ### Multiplying scalars

      iex> NxSgemm.multiply().(1, 2)
      #Nx.Tensor<
        s32
        2
      >

  ### Multiplying tensors and scalars

      iex> NxSgemm.multiply().(Nx.tensor([1, 2, 3], names: [:data], type: :u8), 1)
      #Nx.Tensor<
        u8[data: 3]
        [1, 2, 3]
      >

      iex> NxSgemm.multiply().(1, Nx.tensor([1, 2, 3], names: [:data], type: :u8))
      #Nx.Tensor<
        u8[data: 3]
        [1, 2, 3]
      >

      iex> NxSgemm.multiply().(Nx.tensor([1.0, 2.0, 3.0], names: [:data], type: :f32), 2.0)
      #Nx.Tensor<
        f32[data: 3]
        [2.0, 4.0, 6.0]
      >

      iex> NxSgemm.multiply().(2.0, Nx.tensor([1.0, 2.0, 3.0], names: [:data], type: :f32))
      #Nx.Tensor<
        f32[data: 3]
        [2.0, 4.0, 6.0]
      >
  """
  def multiply() do
    if SME.available?() and SME.use?() do
      &multiply_sme/2
    else
      &multiply_n/2
    end
  end

  defp multiply_n(a, b) when is_integer(a) and is_integer(b) do
    Nx.tensor(a * b, type: :s32)
  end

  defp multiply_n(a, b) when is_float(b) do
    case Nx.type(a) do
      {:f, 32} ->
        %{
          a
          | data: %{
              a.data
              | state: mul_nif_f32_tensor_f32_scalar(Nx.size(a), a.data.state, b)
            }
        }
    end
  end

  defp multiply_n(a, b) when is_integer(b) and 0 <= b and b < 256 do
    case Nx.type(a) do
      {:u, 8} ->
        %{
          a
          | data: %{
              a.data
              | state: mul_nif_u8_tensor_u8_scalar(Nx.size(a), a.data.state, b)
            }
        }
    end
  end

  defp multiply_n(a, b) when is_number(a) do
    multiply_n(b, a)
  end

  defp multiply_sme(a, b) when is_integer(a) and is_integer(b) do
    Nx.tensor(a * b, type: :s32)
  end

  defp multiply_sme(a, b) when is_float(b) do
    case Nx.type(a) do
      {:f, 32} ->
        %{
          a
          | data: %{
              a.data
              | state: mul_nif_f32_tensor_f32_scalar_sme(Nx.size(a), a.data.state, b)
            }
        }
    end
  end

  defp multiply_sme(a, b) when is_integer(b) and 0 <= b and b < 256 do
    case Nx.type(a) do
      {:u, 8} ->
        %{
          a
          | data: %{
              a.data
              | state: mul_nif_u8_tensor_u8_scalar(Nx.size(a), a.data.state, b)
            }
        }
    end
  end

  defp multiply_sme(a, b) when is_number(a) do
    multiply_sme(b, a)
  end

  defp mul_nif_f32_tensor_f32_scalar(_size, _a, _b),
    do: raise("NIF mul_nif_f32_tensor_f32_scalar/3 not implemented")

  defp mul_nif_f32_tensor_f32_scalar_sme(_size, _a, _b),
    do: raise("NIF mul_nif_f32_tensor_f32_scalar_sme/3 not implemented")

  defp mul_nif_u8_tensor_u8_scalar(_size, _a, _b),
    do: raise("NIF mul_nif_u8_tensor_u8_scalar/3 not implemented")

  @doc """
  Returns the dot product of two tensors.

  Given `a` and `b`, computes the dot product according to the following rules:

  * If both `a` and `b` are scalars, it is equivalent to `a * b`.
  * If `a` is a scalar and `b` is a tensor, it is equivalent to `Nx.multiply(a, b)`.
  * If `a` is a tensor and `b` is a scalar, it is equivalent to `Nx.multiply(a, b)`.
  * If both `a` and `b` are 1-D tensors (vectors), it is the sum of the element-wise product between `a` and `b`. The lengths of `a` and `b` must be equal.
  * If both `a` and `b` are 2-D tensors (matrices), it is equivalent to matrix-multiplication.
  * If either `a` or `b` is a 1-D tensor, and the other is an n-D tensor, it is the sum of the element-wise product along the last axis of `a` or `b`. The length of the 1-D tensor must match the last dimension of the n-D tensor.
  * If `a` is an n-D tensor and `b` is an m-D tensor, it is the sum of the element-wise product along the last axis of `a` and the second-to-last axis of `b`. The last dimension of `a` must match the second-to-last dimension of `b`.

  ## Examples

      iex> left = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
      iex> right = Nx.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]])
      iex> Nx.dot(left, right)
      #Nx.Tensor<
        f32[2][2]
        [
          [58.0, 64.0],
          [139.0, 154.0]
        ]
      >
  """
  def dot(a, b) do
    case {Nx.type(a), Nx.type(b), Nx.shape(a), Nx.shape(b)} do
      {{:f, 32}, {:f, 32}, {m, n}, {n, o}} ->
        c = Nx.iota({m, o}, type: {:f, 32})

        %{
          c
          | data: %{
              c.data
              | state: dot_nif_f32_matrix_f32_matrix(m, o, n, a.data.state, b.data.state)
            }
        }
    end
  end

  defp dot_nif_f32_matrix_f32_matrix(_m, _o, _n, _a, _b),
    do: raise("NIF dot_nif_f32_matrix_f32_matrix/5 not implemented")
end
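
The SME helper module used above (SME.available?/0, SME.use?/0) and by the benchmark below (SME.set_use/1) is not listed in this article. A minimal sketch of what it might look like, assuming compile-time detection mirrors the SME_AVAILABLE flag in the C code, is:

defmodule SME do
  @moduledoc false

  # Hypothetical: detect at compile time whether the NIF was built with SME.
  # The real module may determine this differently.
  @sme_available System.get_env("SME_AVAILABLE") != nil

  def available?, do: @sme_available

  # Runtime switch flipped by the benchmark's before_scenario callbacks.
  def use?, do: :persistent_term.get({__MODULE__, :use}, false)

  def set_use(flag) when is_boolean(flag),
    do: :persistent_term.put({__MODULE__, :use}, flag)
end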
nif_src/libnif.c
#include <erl_nif.h>
#include <stdbool.h>
#include <stdint.h>

#ifdef USE_OPEN_BLAS
#include <cblas.h>
#else // USE_OPEN_BLAS
#include <Accelerate/Accelerate.h>
#endif // USE_OPEN_BLAS

#ifdef SME_AVAILABLE
#include <arm_sme.h>
#endif // SME_AVAILABLE

static ERL_NIF_TERM ok(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    return enif_make_atom(env, "ok");
}

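// Multiply an f32 tensor by a scalar (received as a double, narrowed to f32):
// copy the input into the output binary with cblas_scopy, then scale it
// in place with cblas_sscal.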
static ERL_NIF_TERM mul_nif_f32_tensor_f32_scalar(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    if (__builtin_expect(argc != 3, false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 vec_size;
    if (__builtin_expect(!enif_get_uint64(env, argv[0], &vec_size), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM binary_term = argv[1];
    ErlNifBinary in_data;
    if (__builtin_expect(!enif_inspect_binary(env, binary_term, &in_data), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM double_term = argv[2];
    double factor;
    if (__builtin_expect(!enif_get_double(env, double_term, &factor), false)) {
        return enif_make_badarg(env);
    }

    float *in = (float *)in_data.data;
    ErlNifBinary out_data;
    if (__builtin_expect(!enif_alloc_binary(vec_size * sizeof(float), &out_data), false)) {
        return enif_make_badarg(env);
    }

    float *out = (float *)out_data.data;

    cblas_scopy((int)vec_size, in, 1, out, 1);
    cblas_sscal((int)vec_size, (float) factor, out, 1);

    return enif_make_binary(env, &out_data);
}

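// Multiply a u8 tensor by a u8 scalar with a plain C loop (no BLAS routine
// covers integer scaling); the product wraps modulo 256.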
static ERL_NIF_TERM mul_nif_u8_tensor_u8_scalar(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    if (__builtin_expect(argc != 3, false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 vec_size;
    if (__builtin_expect(!enif_get_uint64(env, argv[0], &vec_size), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM binary_term = argv[1];
    ErlNifBinary in_data;
    if (__builtin_expect(!enif_inspect_binary(env, binary_term, &in_data), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM uint_term = argv[2];
    unsigned int factor;
    if (__builtin_expect(!enif_get_uint(env, uint_term, &factor), false)) {
        return enif_make_badarg(env);
    }

    uint8_t *in = (uint8_t *)in_data.data;
    ErlNifBinary out_data;
    if (__builtin_expect(!enif_alloc_binary(vec_size * sizeof(uint8_t), &out_data), false)) {
        return enif_make_badarg(env);
    }

    uint8_t *out = (uint8_t *)out_data.data;

    for(ErlNifUInt64 i = 0; i < vec_size; i++) {
        out[i] = (uint8_t) (in[i] * factor); 
    }

    return enif_make_binary(env, &out_data);
}

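// (m x n) . (n x o) single-precision matrix product via cblas_sgemm
// (row-major, no transpose, alpha = 1.0, beta = 0.0).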
static ERL_NIF_TERM dot_nif_f32_matrix_f32_matrix(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    if (__builtin_expect(argc != 5, false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 m;
    if (__builtin_expect(!enif_get_uint64(env, argv[0], &m), false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 o;
    if (__builtin_expect(!enif_get_uint64(env, argv[1], &o), false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 n;
    if (__builtin_expect(!enif_get_uint64(env, argv[2], &n), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM binary_term_a = argv[3];
    ErlNifBinary a_data;
    if (__builtin_expect(!enif_inspect_binary(env, binary_term_a, &a_data), false)) {
        return enif_make_badarg(env);
    }
    float *a = (float *)a_data.data;

    ERL_NIF_TERM binary_term_b = argv[4];
    ErlNifBinary b_data;
    if (__builtin_expect(!enif_inspect_binary(env, binary_term_b, &b_data), false)) {
        return enif_make_badarg(env);
    }
    float *b = (float *)b_data.data;

    ErlNifBinary c_data;
    if (__builtin_expect(!enif_alloc_binary(m * o * sizeof(float), &c_data), false)) {
        return enif_make_badarg(env);
    }
    float *c = (float *)c_data.data;

    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, o, n, 1.0, a, n, b, o, 0.0, c, o);

    return enif_make_binary(env, &c_data);
}

#ifdef SME_AVAILABLE
__arm_locally_streaming
__arm_new("za")
void multiply_factor_in_to_out(float factor, float *in, float *out, ErlNifUInt64 vec_size)
{
    // Duplicate scalar across all lanes
    svfloat32_t factor_vec = svdup_f32((float32_t)factor);

    // Loop over the vector in chunks of vector size
    // svcntw() gives the number of elements per register
    for (ErlNifUInt64 i = 0; i < vec_size; i += svcntw()) {
        svbool_t mask = svwhilelt_b32((uint64_t)i, (uint64_t)vec_size);

        // Load the vector chunk into an SVE register
        svfloat32_t vec_chunk = svld1_f32(mask, &in[i]);

        // Perform element-wise multiplication
        svfloat32_t result = svmul_f32_m(mask, vec_chunk, factor_vec);
        
        // Store the result back into memory
        svst1_f32(mask, &out[i], result);   
    }
}

static ERL_NIF_TERM mul_nif_f32_tensor_f32_scalar_sme(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    if (__builtin_expect(argc != 3, false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 vec_size;
    if (__builtin_expect(!enif_get_uint64(env, argv[0], &vec_size), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM binary_term = argv[1];
    ErlNifBinary in_data;
    if (__builtin_expect(!enif_inspect_binary(env, binary_term, &in_data), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM double_term = argv[2];
    double factor_d;
    if (__builtin_expect(!enif_get_double(env, double_term, &factor_d), false)) {
        return enif_make_badarg(env);
    }

    float *in = (float *)in_data.data;
    ErlNifBinary out_data;
    if (__builtin_expect(!enif_alloc_binary(vec_size * sizeof(float), &out_data), false)) {
        return enif_make_badarg(env);
    }

    float *out = (float *)out_data.data;

    multiply_factor_in_to_out((float)factor_d, in, out, vec_size);

    return enif_make_binary(env, &out_data);    
}
#endif // SME_AVAILABLE

static ErlNifFunc nif_funcs [] =
{
#ifdef SME_AVAILABLE
    {"mul_nif_f32_tensor_f32_scalar_sme", 3, mul_nif_f32_tensor_f32_scalar_sme},
#endif // SME_AVAILABLE
    {"ok", 0, ok},
    {"mul_nif_f32_tensor_f32_scalar", 3, mul_nif_f32_tensor_f32_scalar},
    {"mul_nif_u8_tensor_u8_scalar", 3, mul_nif_u8_tensor_u8_scalar},
    {"dot_nif_f32_matrix_f32_matrix", 5, dot_nif_f32_matrix_f32_matrix}
};

ERL_NIF_INIT(Elixir.NxSgemm, nif_funcs, NULL, NULL, NULL, NULL)

Benchmark

In the benchmark.exs script below, the before_scenario callbacks flip SME.set_use/1, so that NxSgemm.multiply/0 hands back the SME kernel in one scenario and the Apple BLAS path in the other.

mix.exs
defmodule NxSgemmBenchOpenblas.MixProject do
  use Mix.Project

  def project do
    [
      app: :nx_sgemm_bench_openblas,
      version: "0.1.0",
      elixir: "~> 1.17",
      start_permanent: Mix.env() == :prod,
      deps: deps()
    ]
  end

  # Run "mix help compile.app" to learn about applications.
  def application do
    [
      extra_applications: [:logger]
    ]
  end

  # Run "mix help deps" to learn about dependencies.
  defp deps do
    [
      # {:dep_from_hexpm, "~> 0.3.0"},
      # {:dep_from_git, git: "https://github.com/elixir-lang/my_dep.git", tag: "0.1.0"}
      {:nx_sgemm, github: "zacky1972/nx_sgemm", branch: "main"},
      {:benchee, "~> 1.0", only: :dev}
    ]
  end
end
benchmark.exs
Benchee.run(
  %{
    "Nx" => fn input -> Nx.multiply(input, 2.0) end,
    "AppleBLAS" =>
      {
        fn {input, mul} -> mul.(input, 2.0) end,
        before_scenario: fn input ->
          SME.set_use(false)
          {input, NxSgemm.multiply()}
        end
      },
    "SME" =>
      {
        fn {input, mul} -> mul.(input, 2.0) end,
        before_scenario: fn input ->
          SME.set_use(true)
          {input, NxSgemm.multiply()}
        end
      }
  },
  inputs: %{
    "Small" => Nx.iota({1_000}) |> Nx.multiply(1.0),
    "Medium" => Nx.iota({10_000}) |> Nx.multiply(1.0),
    "Bigger" => Nx.iota({100_000}) |> Nx.multiply(1.0)
  }
)
mix run -r benchmark.exs

Results

Operating System: macOS
CPU Information: Apple M4 Pro
Number of Available Cores: 14
Available memory: 64 GB
Elixir 1.18.1
Erlang 27.2
JIT enabled: true

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: Bigger, Medium, Small
Estimated total run time: 1 min 3 s

Benchmarking AppleBLAS with input Bigger ...
Benchmarking AppleBLAS with input Medium ...
Benchmarking AppleBLAS with input Small ...
Benchmarking Nx with input Bigger ...
Benchmarking Nx with input Medium ...
Benchmarking Nx with input Small ...
Benchmarking SME with input Bigger ...
Benchmarking SME with input Medium ...
Benchmarking SME with input Small ...
Calculating statistics...
Formatting results...

##### With input Bigger #####
Name                ips        average  deviation         median         99th %
SME            119.72 K        8.35 μs    ±62.59%        8.25 μs        8.67 μs
AppleBLAS      114.39 K        8.74 μs    ±50.61%        8.71 μs        9.33 μs
Nx               0.27 K     3706.81 μs     ±5.99%     3648.46 μs     4257.40 μs

Comparison: 
SME            119.72 K
AppleBLAS      114.39 K - 1.05x slower +0.39 μs
Nx               0.27 K - 443.79x slower +3698.46 μs

##### With input Medium #####
Name                ips        average  deviation         median         99th %
AppleBLAS      924.71 K        1.08 μs  ±1724.90%           1 μs        1.58 μs
SME            806.73 K        1.24 μs  ±1400.91%        1.13 μs        1.75 μs
Nx               3.20 K      312.14 μs    ±13.43%      295.25 μs      438.70 μs

Comparison: 
AppleBLAS      924.71 K
SME            806.73 K - 1.15x slower +0.158 μs
Nx               3.20 K - 288.64x slower +311.06 μs

##### With input Small #####
Name                ips        average  deviation         median         99th %
AppleBLAS        3.91 M      255.62 ns ±10082.28%         208 ns        2875 ns
SME              2.05 M      487.27 ns  ±3496.97%         417 ns        3708 ns
Nx             0.0339 M    29500.87 ns     ±6.31%       28750 ns       33250 ns

Comparison: 
AppleBLAS        3.91 M
SME              2.05 M - 1.91x slower +231.65 ns
Nx             0.0339 M - 115.41x slower +29245.24 ns

The SME version finally overtakes Apple BLAS only at around 100,000 elements (about 1.05x faster on Bigger); Apple BLAS is still 1.15x faster at 10,000 elements and 1.91x faster at 1,000. Both NIF paths beat pure Nx by roughly two orders of magnitude.
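
Given this crossover, a possible refinement, sketched here purely as an idea and not part of this article's code, is to dispatch on vector size and keep the Apple BLAS path below the measured break-even point:

  # Hypothetical: inside NxSgemm, pick the kernel per call based on size.
  # The 100_000 threshold is the break-even point suggested above.
  def multiply_auto() do
    fn a, b ->
      large? = is_struct(a, Nx.Tensor) and Nx.size(a) >= 100_000

      if large? and SME.available?() do
        multiply_sme(a, b)
      else
        multiply_n(a, b)
      end
    end
  end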
