At last, I have written a program that multiplies a vector by a scalar using the Scalable Matrix Extension (SME), and compared it against an Apple BLAS version.
SME series
- Apple Silicon M4 adds Scalable Matrix Extension (SME) instructions and more on top of the M3 series
- SME Diary Part 1: Getting the vector length (SVL) of the Scalable Matrix Extension (SME) on the Apple Silicon M4
- SME Diary Part 2: Does the Apple Silicon M4 lack the CVTW instruction?
- SME Diary Part 3: Inferring which instructions the Apple Silicon M4 implements from sysctl hw output and the documentation
- SME Diary Part 4: Running CNTW in Streaming SVE mode
- SME Diary Part 5: Running CNTW in Streaming SVE mode, Part 2
- SME Diary Part 6: Running svcntw() and svcntsw() in Streaming SVE mode
- SME Diary Part 7: On the relationship between svcntw() and the result of the RDSVL instruction
- SME Diary Part 8: Investigating __arm_new("za")
- SME Diary Part 9: Investigating the state of SME support in OpenBLAS
- SME Diary Part 10: Running CNTW in Streaming SVE mode (revisited)
- SME Diary Part 11: Checking whether SME is used by OpenBLAS's SSCAL, Part 1
- SME Diary Part 12: Checking whether SME is used by OpenBLAS's SSCAL, Part 2
- SME Diary Part 13: Checking whether SME is used by OpenBLAS's SSCAL, Part 3
- SME Diary Part 14: Checking whether SME is used by AppleBLAS's SSCAL, Part 1
- SME Diary Part 15: Checking whether SME is used by AppleBLAS's SGEMM, Part 1
- SME Diary Part 16: A technical poem on the future prospects of Scalable Matrix Extension (SME) research
- SME Diary Part 17: Investigating __arm_new("za"), Part 2
- SME Diary Part 18: Determining from Elixir whether SME is available
Source code
This time I made a breaking change: multiply now returns a function object first, and the actual computation happens when you call that function.
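For example:

mul = NxSgemm.multiply()   # selects the SME or BLAS implementation up front
mul.(Nx.tensor([1.0, 2.0, 3.0], type: :f32), 2.0)
#=> #Nx.Tensor<
#=>   f32[3]
#=>   [2.0, 4.0, 6.0]
#=> >

First, the SME kernel on the C side: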
__arm_locally_streaming
__arm_new("za")
void multiply_factor_in_to_out(float factor, float *in, float *out, ErlNifUInt64 vec_size)
{
    // Duplicate the scalar across all lanes
    svfloat32_t factor_vec = svdup_f32((float32_t)factor);

    // Loop over the vector in chunks of the streaming vector length;
    // svcntw() gives the number of 32-bit elements per vector register
    for (ErlNifUInt64 i = 0; i < vec_size; i += svcntw()) {
        svbool_t mask = svwhilelt_b32((uint64_t)i, (uint64_t)vec_size);
        // Load a vector chunk under the mask
        svfloat32_t vec_chunk = svld1_f32(mask, &in[i]);
        // Element-wise multiply under the mask
        svfloat32_t result = svmul_f32_m(mask, vec_chunk, factor_vec);
        // Store the result back to memory under the mask
        svst1_f32(mask, &out[i], result);
    }
}
- svdup_f32 broadcasts the scalar constant across all lanes of a vector.
- svwhilelt_b32 builds the predicate mask: every element is active on most iterations, and the tail at the end of the loop is masked off appropriately.
- svld1_f32 loads a vector chunk from memory under the mask.
- svmul_f32_m multiplies the vectors element-wise under the mask.
- svst1_f32 stores the vector back to memory under the mask.
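Functionally, the whole loop is just out[i] = in[i] * factor with tail masking. As a sanity check for the NIF, a pure-Elixir reference that performs the same transformation on an f32 binary can be handy (a sketch of mine; the module and function names are hypothetical):

defmodule RefMul do
  # Multiplies every native-endian 32-bit float in `bin` by `factor`.
  # Mirrors what mul_nif_f32_tensor_f32_scalar_sme computes, for testing.
  def mul_f32(bin, factor) when is_binary(bin) and is_float(factor) do
    for <<x::float-size(32)-native <- bin>>, into: <<>> do
      <<(x * factor)::float-size(32)-native>>
    end
  end
end

Comparing its output against the NIF's output on the same binary gives a quick correctness check.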
lib/nx_sgemm.ex
defmodule NxSgemm do
  @moduledoc """
  Documentation for `NxSgemm`.
  """
  require Logger

  @on_load :load_nif

  @doc false
  def load_nif do
    nif_file = ~c'#{Application.app_dir(:nx_sgemm, "priv/libnif")}'

    case :erlang.load_nif(nif_file, 0) do
      :ok -> :ok
      {:error, {:reload, _}} -> :ok
      {:error, reason} -> Logger.error("Failed to load NIF: #{inspect(reason)}")
    end
  end

  @doc """
  ok.

  ## Examples

      iex> NxSgemm.ok()
      :ok

  """
  def ok(), do: :erlang.nif_error(:not_loaded)

  @doc """
  Element-wise multiplication of two tensors.

  If a number is given, it is converted to a tensor.

  It will broadcast tensors whenever the dimensions do not match and broadcasting is possible.

  ## Examples

  ### Multiplying scalars

      iex> NxSgemm.multiply().(1, 2)
      #Nx.Tensor<
        s32
        2
      >

  ### Multiplying tensors and scalars

      iex> NxSgemm.multiply().(Nx.tensor([1, 2, 3], names: [:data], type: :u8), 1)
      #Nx.Tensor<
        u8[data: 3]
        [1, 2, 3]
      >

      iex> NxSgemm.multiply().(1, Nx.tensor([1, 2, 3], names: [:data], type: :u8))
      #Nx.Tensor<
        u8[data: 3]
        [1, 2, 3]
      >

      iex> NxSgemm.multiply().(Nx.tensor([1.0, 2.0, 3.0], names: [:data], type: :f32), 2.0)
      #Nx.Tensor<
        f32[data: 3]
        [2.0, 4.0, 6.0]
      >

      iex> NxSgemm.multiply().(2.0, Nx.tensor([1.0, 2.0, 3.0], names: [:data], type: :f32))
      #Nx.Tensor<
        f32[data: 3]
        [2.0, 4.0, 6.0]
      >

  """
  def multiply() do
    if SME.available?() and SME.use?() do
      &multiply_sme/2
    else
      &multiply_n/2
    end
  end

  defp multiply_n(a, b) when is_integer(a) and is_integer(b) do
    Nx.tensor(a * b, type: :s32)
  end

  defp multiply_n(a, b) when is_float(b) do
    case Nx.type(a) do
      {:f, 32} ->
        %{
          a
          | data: %{
              a.data
              | state: mul_nif_f32_tensor_f32_scalar(Nx.size(a), a.data.state, b)
            }
        }
    end
  end

  defp multiply_n(a, b) when is_integer(b) and 0 <= b and b < 256 do
    case Nx.type(a) do
      {:u, 8} ->
        %{
          a
          | data: %{
              a.data
              | state: mul_nif_u8_tensor_u8_scalar(Nx.size(a), a.data.state, b)
            }
        }
    end
  end

  defp multiply_n(a, b) when is_number(a) do
    multiply_n(b, a)
  end

  defp multiply_sme(a, b) when is_integer(a) and is_integer(b) do
    Nx.tensor(a * b, type: :s32)
  end

  defp multiply_sme(a, b) when is_float(b) do
    case Nx.type(a) do
      {:f, 32} ->
        %{
          a
          | data: %{
              a.data
              | state: mul_nif_f32_tensor_f32_scalar_sme(Nx.size(a), a.data.state, b)
            }
        }
    end
  end

  defp multiply_sme(a, b) when is_integer(b) and 0 <= b and b < 256 do
    case Nx.type(a) do
      {:u, 8} ->
        %{
          a
          | data: %{
              a.data
              | state: mul_nif_u8_tensor_u8_scalar(Nx.size(a), a.data.state, b)
            }
        }
    end
  end

  defp multiply_sme(a, b) when is_number(a) do
    multiply_sme(b, a)
  end

  defp mul_nif_f32_tensor_f32_scalar(_size, _a, _b),
    do: raise("NIF mul_nif_f32_tensor_f32_scalar/3 not implemented")

  defp mul_nif_f32_tensor_f32_scalar_sme(_size, _a, _b),
    do: raise("NIF mul_nif_f32_tensor_f32_scalar_sme/3 not implemented")

  defp mul_nif_u8_tensor_u8_scalar(_size, _a, _b),
    do: raise("NIF mul_nif_u8_tensor_u8_scalar/3 not implemented")

  @doc """
  Returns the dot product of two tensors.

  Given `a` and `b`, computes the dot product according to the following rules:

  * If both `a` and `b` are scalars, it is equivalent to `a * b`.
  * If `a` is a scalar and `b` is a tensor, it is equivalent to `Nx.multiply(a, b)`.
  * If `a` is a tensor and `b` is a scalar, it is equivalent to `Nx.multiply(a, b)`.
  * If both `a` and `b` are 1-D tensors (vectors), it is the sum of the element-wise product between `a` and `b`. The lengths of `a` and `b` must be equal.
  * If both `a` and `b` are 2-D tensors (matrices), it is equivalent to matrix multiplication.
  * If either `a` or `b` is a 1-D tensor, and the other is an n-D tensor, it is the sum of the element-wise product along the last axis of `a` or `b`. The length of the 1-D tensor must match the last dimension of the n-D tensor.
  * If `a` is an n-D tensor and `b` is an m-D tensor, it is the sum of the element-wise product along the last axis of `a` and the second-to-last axis of `b`. The last dimension of `a` must match the second-to-last dimension of `b`.

  ## Examples

      iex> left = Nx.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
      iex> right = Nx.tensor([[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]])
      iex> NxSgemm.dot(left, right)
      #Nx.Tensor<
        f32[2][2]
        [
          [58.0, 64.0],
          [139.0, 154.0]
        ]
      >

  """
  def dot(a, b) do
    case {Nx.type(a), Nx.type(b), Nx.shape(a), Nx.shape(b)} do
      {{:f, 32}, {:f, 32}, {m, n}, {n, o}} ->
        c = Nx.iota({m, o}, type: {:f, 32})

        %{
          c
          | data: %{
              c.data
              | state: dot_nif_f32_matrix_f32_matrix(m, o, n, a.data.state, b.data.state)
            }
        }
    end
  end

  defp dot_nif_f32_matrix_f32_matrix(_m, _o, _n, _a, _b),
    do: raise("NIF dot_nif_f32_matrix_f32_matrix/5 not implemented")
end
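The SME module used by multiply/0 above (and by SME.set_use/1 in the benchmark script below) is not shown in this post. A minimal sketch of what it might look like — my assumption, not the actual implementation — reading the sysctl flag discussed in Part 3 of this series and keeping the on/off switch in :persistent_term:

defmodule SME do
  # Hypothetical sketch: gates the SME code path.
  @key {__MODULE__, :use?}

  # Assumes macOS reports FEAT_SME via sysctl; a NIF-based check would also work.
  def available? do
    case System.cmd("sysctl", ["-n", "hw.optional.arm.FEAT_SME"]) do
      {"1\n", 0} -> true
      _ -> false
    end
  end

  def use?, do: :persistent_term.get(@key, true)

  def set_use(flag) when is_boolean(flag), do: :persistent_term.put(@key, flag)
end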
nif_src/libnif.c
#include <erl_nif.h>
#include <stdbool.h>
#include <stdint.h>
#ifdef USE_OPEN_BLAS
#include <cblas.h>
#else // USE_OPEN_BLAS
#include <Accelerate/Accelerate.h>
#endif // USE_OPEN_BLAS
#ifdef SME_AVAILABLE
#include <arm_sme.h>
#endif // SME_AVAILABLE
static ERL_NIF_TERM ok(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    return enif_make_atom(env, "ok");
}

static ERL_NIF_TERM mul_nif_f32_tensor_f32_scalar(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    if (__builtin_expect(argc != 3, false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 vec_size;
    if (__builtin_expect(!enif_get_uint64(env, argv[0], &vec_size), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM binary_term = argv[1];
    ErlNifBinary in_data;
    if (__builtin_expect(!enif_inspect_binary(env, binary_term, &in_data), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM double_term = argv[2];
    double factor;
    if (__builtin_expect(!enif_get_double(env, double_term, &factor), false)) {
        return enif_make_badarg(env);
    }

    float *in = (float *)in_data.data;

    ErlNifBinary out_data;
    if (__builtin_expect(!enif_alloc_binary(vec_size * sizeof(float), &out_data), false)) {
        return enif_make_badarg(env);
    }
    float *out = (float *)out_data.data;

    // Copy the input into the output buffer, then scale it in place with BLAS
    cblas_scopy((int)vec_size, in, 1, out, 1);
    cblas_sscal((int)vec_size, (float)factor, out, 1);

    return enif_make_binary(env, &out_data);
}

static ERL_NIF_TERM mul_nif_u8_tensor_u8_scalar(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    if (__builtin_expect(argc != 3, false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 vec_size;
    if (__builtin_expect(!enif_get_uint64(env, argv[0], &vec_size), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM binary_term = argv[1];
    ErlNifBinary in_data;
    if (__builtin_expect(!enif_inspect_binary(env, binary_term, &in_data), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM uint_term = argv[2];
    unsigned int factor;
    if (__builtin_expect(!enif_get_uint(env, uint_term, &factor), false)) {
        return enif_make_badarg(env);
    }

    uint8_t *in = (uint8_t *)in_data.data;

    ErlNifBinary out_data;
    if (__builtin_expect(!enif_alloc_binary(vec_size * sizeof(uint8_t), &out_data), false)) {
        return enif_make_badarg(env);
    }
    uint8_t *out = (uint8_t *)out_data.data;

    for (ErlNifUInt64 i = 0; i < vec_size; i++) {
        out[i] = (uint8_t)(in[i] * factor);
    }

    return enif_make_binary(env, &out_data);
}

static ERL_NIF_TERM dot_nif_f32_matrix_f32_matrix(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    if (__builtin_expect(argc != 5, false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 m;
    if (__builtin_expect(!enif_get_uint64(env, argv[0], &m), false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 o;
    if (__builtin_expect(!enif_get_uint64(env, argv[1], &o), false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 n;
    if (__builtin_expect(!enif_get_uint64(env, argv[2], &n), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM binary_term_a = argv[3];
    ErlNifBinary a_data;
    if (__builtin_expect(!enif_inspect_binary(env, binary_term_a, &a_data), false)) {
        return enif_make_badarg(env);
    }
    float *a = (float *)a_data.data;

    ERL_NIF_TERM binary_term_b = argv[4];
    ErlNifBinary b_data;
    if (__builtin_expect(!enif_inspect_binary(env, binary_term_b, &b_data), false)) {
        return enif_make_badarg(env);
    }
    float *b = (float *)b_data.data;

    ErlNifBinary c_data;
    if (__builtin_expect(!enif_alloc_binary(m * o * sizeof(float), &c_data), false)) {
        return enif_make_badarg(env);
    }
    float *c = (float *)c_data.data;

    // C (m x o) = A (m x n) * B (n x o), row-major
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, m, o, n, 1.0, a, n, b, o, 0.0, c, o);

    return enif_make_binary(env, &c_data);
}

#ifdef SME_AVAILABLE
__arm_locally_streaming
__arm_new("za")
void multiply_factor_in_to_out(float factor, float *in, float *out, ErlNifUInt64 vec_size)
{
    // Duplicate the scalar across all lanes
    svfloat32_t factor_vec = svdup_f32((float32_t)factor);

    // Loop over the vector in chunks of the streaming vector length;
    // svcntw() gives the number of 32-bit elements per vector register
    for (ErlNifUInt64 i = 0; i < vec_size; i += svcntw()) {
        svbool_t mask = svwhilelt_b32((uint64_t)i, (uint64_t)vec_size);
        // Load a vector chunk under the mask
        svfloat32_t vec_chunk = svld1_f32(mask, &in[i]);
        // Element-wise multiply under the mask
        svfloat32_t result = svmul_f32_m(mask, vec_chunk, factor_vec);
        // Store the result back to memory under the mask
        svst1_f32(mask, &out[i], result);
    }
}

static ERL_NIF_TERM mul_nif_f32_tensor_f32_scalar_sme(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    if (__builtin_expect(argc != 3, false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 vec_size;
    if (__builtin_expect(!enif_get_uint64(env, argv[0], &vec_size), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM binary_term = argv[1];
    ErlNifBinary in_data;
    if (__builtin_expect(!enif_inspect_binary(env, binary_term, &in_data), false)) {
        return enif_make_badarg(env);
    }

    ERL_NIF_TERM double_term = argv[2];
    double factor_d;
    if (__builtin_expect(!enif_get_double(env, double_term, &factor_d), false)) {
        return enif_make_badarg(env);
    }

    float *in = (float *)in_data.data;

    ErlNifBinary out_data;
    if (__builtin_expect(!enif_alloc_binary(vec_size * sizeof(float), &out_data), false)) {
        return enif_make_badarg(env);
    }
    float *out = (float *)out_data.data;

    multiply_factor_in_to_out((float)factor_d, in, out, vec_size);

    return enif_make_binary(env, &out_data);
}
#endif // SME_AVAILABLE

static ErlNifFunc nif_funcs[] =
{
#ifdef SME_AVAILABLE
    {"mul_nif_f32_tensor_f32_scalar_sme", 3, mul_nif_f32_tensor_f32_scalar_sme},
#endif // SME_AVAILABLE
    {"ok", 0, ok},
    {"mul_nif_f32_tensor_f32_scalar", 3, mul_nif_f32_tensor_f32_scalar},
    {"mul_nif_u8_tensor_u8_scalar", 3, mul_nif_u8_tensor_u8_scalar},
    {"dot_nif_f32_matrix_f32_matrix", 5, dot_nif_f32_matrix_f32_matrix}
};

ERL_NIF_INIT(Elixir.NxSgemm, nif_funcs, NULL, NULL, NULL, NULL)
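Before benchmarking, it is worth a quick check that the SME path and plain Nx agree. A sketch, assuming the NIF was built with SME_AVAILABLE defined:

t = Nx.iota({1_000}) |> Nx.multiply(1.0)
SME.set_use(true)
sme_result = NxSgemm.multiply().(t, 2.0)
# Multiplying by 2.0 is exact in f32, so exact equality is a fair comparison
Nx.all(Nx.equal(sme_result, Nx.multiply(t, 2.0)))
#=> #Nx.Tensor<
#=>   u8
#=>   1
#=> >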
Benchmark
mix.exs
defmodule NxSgemmBenchOpenblas.MixProject do
  use Mix.Project

  def project do
    [
      app: :nx_sgemm_bench_openblas,
      version: "0.1.0",
      elixir: "~> 1.17",
      start_permanent: Mix.env() == :prod,
      deps: deps()
    ]
  end

  # Run "mix help compile.app" to learn about applications.
  def application do
    [
      extra_applications: [:logger]
    ]
  end

  # Run "mix help deps" to learn about dependencies.
  defp deps do
    [
      # {:dep_from_hexpm, "~> 0.3.0"},
      # {:dep_from_git, git: "https://github.com/elixir-lang/my_dep.git", tag: "0.1.0"}
      {:nx_sgemm, github: "zacky1972/nx_sgemm", branch: "main"},
      {:benchee, "~> 1.0", only: :dev}
    ]
  end
end
benchmark.exs
Benchee.run(
  %{
    "Nx" => fn input -> Nx.multiply(input, 2.0) end,
    "AppleBLAS" =>
      {
        fn {input, mul} -> mul.(input, 2.0) end,
        before_scenario: fn input ->
          SME.set_use(false)
          {input, NxSgemm.multiply()}
        end
      },
    "SME" =>
      {
        fn {input, mul} -> mul.(input, 2.0) end,
        before_scenario: fn input ->
          SME.set_use(true)
          {input, NxSgemm.multiply()}
        end
      }
  },
  inputs: %{
    "Small" => Nx.iota({1_000}) |> Nx.multiply(1.0),
    "Medium" => Nx.iota({10_000}) |> Nx.multiply(1.0),
    "Bigger" => Nx.iota({100_000}) |> Nx.multiply(1.0)
  }
)
mix run -r benchmark.exs
Results
Operating System: macOS
CPU Information: Apple M4 Pro
Number of Available Cores: 14
Available memory: 64 GB
Elixir 1.18.1
Erlang 27.2
JIT enabled: true
Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: Bigger, Medium, Small
Estimated total run time: 1 min 3 s
Benchmarking AppleBLAS with input Bigger ...
Benchmarking AppleBLAS with input Medium ...
Benchmarking AppleBLAS with input Small ...
Benchmarking Nx with input Bigger ...
Benchmarking Nx with input Medium ...
Benchmarking Nx with input Small ...
Benchmarking SME with input Bigger ...
Benchmarking SME with input Medium ...
Benchmarking SME with input Small ...
Calculating statistics...
Formatting results...
##### With input Bigger #####
Name ips average deviation median 99th %
SME 119.72 K 8.35 μs ±62.59% 8.25 μs 8.67 μs
AppleBLAS 114.39 K 8.74 μs ±50.61% 8.71 μs 9.33 μs
Nx 0.27 K 3706.81 μs ±5.99% 3648.46 μs 4257.40 μs
Comparison:
SME 119.72 K
AppleBLAS 114.39 K - 1.05x slower +0.39 μs
Nx 0.27 K - 443.79x slower +3698.46 μs
##### With input Medium #####
Name ips average deviation median 99th %
AppleBLAS 924.71 K 1.08 μs ±1724.90% 1 μs 1.58 μs
SME 806.73 K 1.24 μs ±1400.91% 1.13 μs 1.75 μs
Nx 3.20 K 312.14 μs ±13.43% 295.25 μs 438.70 μs
Comparison:
AppleBLAS 924.71 K
SME 806.73 K - 1.15x slower +0.158 μs
Nx 3.20 K - 288.64x slower +311.06 μs
##### With input Small #####
Name ips average deviation median 99th %
AppleBLAS 3.91 M 255.62 ns ±10082.28% 208 ns 2875 ns
SME 2.05 M 487.27 ns ±3496.97% 417 ns 3708 ns
Nx 0.0339 M 29500.87 ns ±6.31% 28750 ns 33250 ns
Comparison:
AppleBLAS 3.91 M
SME 2.05 M - 1.91x slower +231.65 ns
Nx 0.0339 M - 115.41x slower +29245.24 ns
It is only at around 100,000 elements that the SME version finally overtakes Apple BLAS. My guess is that the fixed per-call overhead of entering and leaving streaming SVE mode (plus the ZA state management implied by __arm_locally_streaming and __arm_new("za")) dominates for small vectors and is only amortized at larger sizes.