4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Scalable Matrix Extension (SME)Advent Calendar 2024

Day 12

SME日記その11 OpenBLASのSSCALでSMEが使われているかを検証してみる Part.1

Last updated at Posted at 2024-12-11

OpenBLASでSMEが使われている可能性に思い至ったので,手始めにSSCALでSMEが使われているかを検証することを試みました.

SMEシリーズ

ソースコード

ExTaskを使っています.

nif_src/libnif.c
#include <erl_nif.h>
#include <stdbool.h>
#include <stdint.h>
#ifdef USE_OPEN_BLAS
#include <cblas.h>
#else // USE_OPEN_BLAS
#include <Accelerate/Accelerate.h>
#endif // USE_OPEN_BLAS

/* Trivial NIF that returns the atom `ok`; it takes no arguments from the caller. */
static ERL_NIF_TERM ok(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    /* The argument list is intentionally ignored. */
    (void)argc;
    (void)argv;

    const ERL_NIF_TERM ok_atom = enif_make_atom(env, "ok");
    return ok_atom;
}

/*
 * mul_nif_f32_tensor_f32_scalar(Size, Binary, Factor) -> Binary
 *
 * Multiplies the first `Size` single-precision floats stored in `Binary` by
 * `Factor` and returns the products as a newly allocated binary.
 *
 *   argv[0]  Size   - number of float elements to process (unsigned integer)
 *   argv[1]  Binary - input data, read as an array of `float`
 *   argv[2]  Factor - scalar multiplier (Erlang float; narrowed to `float`)
 *
 * Raises badarg when the argument count or any argument type is wrong, when
 * `Size` does not fit the `int` length taken by CBLAS, when `Binary` holds
 * fewer than `Size` floats, or when the output binary cannot be allocated.
 */
static ERL_NIF_TERM mul_nif_f32_tensor_f32_scalar(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    if (__builtin_expect(argc != 3, false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 vec_size;
    if (__builtin_expect(!enif_get_uint64(env, argv[0], &vec_size), false)) {
        return enif_make_badarg(env);
    }

    ErlNifBinary in_data;
    if (__builtin_expect(!enif_inspect_binary(env, argv[1], &in_data), false)) {
        return enif_make_badarg(env);
    }

    double factor;
    if (__builtin_expect(!enif_get_double(env, argv[2], &factor), false)) {
        return enif_make_badarg(env);
    }

    /*
     * Reject sizes the rest of the function cannot handle safely:
     *  - CBLAS takes the length as `int`, so a larger value would be
     *    truncated by the cast below (INT32_MAX is used as the bound);
     *  - the input binary must really contain `vec_size` floats, otherwise
     *    the copy would read past its end. Dividing instead of multiplying
     *    keeps the comparison free of overflow.
     */
    if (__builtin_expect(vec_size > (ErlNifUInt64)INT32_MAX
                         || vec_size > in_data.size / sizeof(float), false)) {
        return enif_make_badarg(env);
    }

    ErlNifBinary out_data;
    if (__builtin_expect(!enif_alloc_binary((size_t)vec_size * sizeof(float), &out_data), false)) {
        return enif_make_badarg(env);
    }

    const float *in = (const float *)in_data.data;
    float *out = (float *)out_data.data;

    /* SSCAL scales in place, so copy the input first and scale the copy;
     * the caller's binary is never modified. */
    cblas_scopy((int)vec_size, in, 1, out, 1);
    cblas_sscal((int)vec_size, (float)factor, out, 1);

    return enif_make_binary(env, &out_data);
}

/*
 * mul_nif_u8_tensor_u8_scalar(Size, Binary, Factor) -> Binary
 *
 * Multiplies the first `Size` bytes of `Binary` by `Factor` and returns the
 * products as a newly allocated binary. Each product is narrowed to
 * `uint8_t`, so results wrap modulo 256.
 *
 *   argv[0]  Size   - number of byte elements to process (unsigned integer)
 *   argv[1]  Binary - input data, read as an array of `uint8_t`
 *   argv[2]  Factor - scalar multiplier (unsigned integer)
 *
 * Raises badarg when the argument count or any argument type is wrong, when
 * `Binary` holds fewer than `Size` bytes, or when the output binary cannot
 * be allocated.
 */
static ERL_NIF_TERM mul_nif_u8_tensor_u8_scalar(ErlNifEnv *env, int argc, const ERL_NIF_TERM argv[])
{
    if (__builtin_expect(argc != 3, false)) {
        return enif_make_badarg(env);
    }

    ErlNifUInt64 vec_size;
    if (__builtin_expect(!enif_get_uint64(env, argv[0], &vec_size), false)) {
        return enif_make_badarg(env);
    }

    ErlNifBinary in_data;
    if (__builtin_expect(!enif_inspect_binary(env, argv[1], &in_data), false)) {
        return enif_make_badarg(env);
    }

    unsigned int factor;
    if (__builtin_expect(!enif_get_uint(env, argv[2], &factor), false)) {
        return enif_make_badarg(env);
    }

    /* The loop below reads `vec_size` bytes from the input, so the binary
     * must be at least that long; otherwise it would read out of bounds. */
    if (__builtin_expect(vec_size > in_data.size, false)) {
        return enif_make_badarg(env);
    }

    ErlNifBinary out_data;
    if (__builtin_expect(!enif_alloc_binary((size_t)vec_size * sizeof(uint8_t), &out_data), false)) {
        return enif_make_badarg(env);
    }

    const uint8_t *in = (const uint8_t *)in_data.data;
    uint8_t *out = (uint8_t *)out_data.data;

    for (ErlNifUInt64 i = 0; i < vec_size; i++) {
        /* Narrowing cast keeps only the low 8 bits of the product. */
        out[i] = (uint8_t)(in[i] * factor);
    }

    return enif_make_binary(env, &out_data);
}

/* NIF dispatch table; each entry is {function name, arity, C implementation, flags}. */
static ErlNifFunc nif_funcs[] = {
    {"ok", 0, ok, 0},
    {"mul_nif_f32_tensor_f32_scalar", 3, mul_nif_f32_tensor_f32_scalar, 0},
    {"mul_nif_u8_tensor_u8_scalar", 3, mul_nif_u8_tensor_u8_scalar, 0}
};

/* Binds the table to the Elixir module NxSgemm; every optional callback is NULL. */
ERL_NIF_INIT(Elixir.NxSgemm, nif_funcs, NULL, NULL, NULL, NULL)
lib/nx_sgemm.ex
defmodule NxSgemm do
  @moduledoc """
  Scalar multiplication for `Nx` tensors, implemented by native functions (NIFs).

  The native library is loaded from `priv/libnif` of the `:nx_sgemm` application when
  this module is loaded.
  """
  require Logger

  @on_load :load_nif

  @doc false
  def load_nif do
    nif_file = ~c'#{Application.app_dir(:nx_sgemm, "priv/libnif")}'

    case :erlang.load_nif(nif_file, 0) do
      :ok -> :ok
      {:error, {:reload, _}} -> :ok
      # A load failure is only logged here; calling a NIF-backed function afterwards
      # hits the stubs at the bottom of this module.
      {:error, reason} -> Logger.error("Failed to load NIF: #{inspect(reason)}")
    end
  end

  @doc """
  Returns `:ok` from the native library.

  ## Examples

      iex> NxSgemm.ok()
      :ok

  """
  @spec ok() :: :ok
  def ok(), do: :erlang.nif_error(:not_loaded)

  @doc """
  Multiplies a tensor by a scalar, or two integer scalars.

  The supported argument combinations are:

    * two integers - the product is returned as an `s32` scalar tensor;
    * an `f32` tensor and a float;
    * a `u8` tensor and an integer in `0..255`.

  The scalar may be given on either side of the tensor. Any other combination raises.
  The returned tensor is the given tensor with its data replaced by the products.

  ## Examples

  ### Multiplying scalars

      iex> NxSgemm.multiply(1, 2)
      #Nx.Tensor<
        s32
        2
      >

  ### Multiplying tensors and scalars

      iex> NxSgemm.multiply(Nx.tensor([1, 2, 3], names: [:data], type: :u8), 1)
      #Nx.Tensor<
        u8[data: 3]
        [1, 2, 3]
      >

      iex> NxSgemm.multiply(1, Nx.tensor([1, 2, 3], names: [:data], type: :u8))
      #Nx.Tensor<
        u8[data: 3]
        [1, 2, 3]
      >

      iex> NxSgemm.multiply(Nx.tensor([1.0, 2.0, 3.0], names: [:data], type: :f32), 2.0)
      #Nx.Tensor<
        f32[data: 3]
        [2.0, 4.0, 6.0]
      >

      iex> NxSgemm.multiply(2.0, Nx.tensor([1.0, 2.0, 3.0], names: [:data], type: :f32))
      #Nx.Tensor<
        f32[data: 3]
        [2.0, 4.0, 6.0]
      >
  """
  @spec multiply(Nx.Tensor.t() | number(), Nx.Tensor.t() | number()) :: Nx.Tensor.t()
  def multiply(a, b) when is_integer(a) and is_integer(b) do
    Nx.tensor(a * b, type: :s32)
  end

  def multiply(a, b) when is_float(b) do
    case Nx.type(a) do
      {:f, 32} ->
        %{
          a
          | data: %{
              a.data
              | state: mul_nif_f32_tensor_f32_scalar(Nx.size(a), a.data.state, b)
            }
        }
    end
  end

  # A single guard is required here: chaining two `when` clauses would OR them and let
  # every integer through, including values that do not fit in a u8.
  def multiply(a, b) when b in 0..255 do
    case Nx.type(a) do
      {:u, 8} ->
        %{
          a
          | data: %{
              a.data
              | state: mul_nif_u8_tensor_u8_scalar(Nx.size(a), a.data.state, b)
            }
        }
    end
  end

  # Scalar on the left: swap the arguments so the tensor-first clauses above apply.
  def multiply(a, b) when is_number(a) do
    multiply(b, a)
  end

  # Stubs replaced by the native implementations once the NIF library is loaded.
  defp mul_nif_f32_tensor_f32_scalar(_size, _a, _b), do: :erlang.nif_error(:not_loaded)
  defp mul_nif_u8_tensor_u8_scalar(_size, _a, _b), do: :erlang.nif_error(:not_loaded)
end

テスト方法

% git clone https://github.com/zacky1972/nx_sgemm.git
Cloning into 'nx_sgemm'...
remote: Enumerating objects: 48, done.
remote: Counting objects: 100% (48/48), done.
remote: Compressing objects: 100% (27/27), done.
Receiving objects: 100% (48/48), 8.99 KiB | 8.99 MiB/s, done.
Resolving deltas: 100% (17/17), done.
remote: Total 48 (delta 17), reused 44 (delta 13), pack-reused 0 (from 0)
% cd nx_sgemm 
nx_sgemm % mix deps.get
Resolving Hex dependencies...
Resolution completed in 0.022s
Unchanged:
  complex 0.5.0
  ex_task 0.3.0
  finch 0.19.0
  hpax 1.0.1
  jason 1.4.4
  mime 2.0.6
  mint 1.6.2
  nimble_options 1.1.1
  nimble_pool 1.1.0
  nx 0.9.2
  req 0.5.8
  telemetry 1.3.0
* Getting ex_task (Hex package)
* Getting nx (Hex package)
* Getting complex (Hex package)
* Getting telemetry (Hex package)
* Getting req (Hex package)
* Getting finch (Hex package)
* Getting jason (Hex package)
* Getting mime (Hex package)
* Getting mint (Hex package)
* Getting nimble_options (Hex package)
* Getting nimble_pool (Hex package)
* Getting hpax (Hex package)
nx_sgemm % mix test
==> mime
Compiling 1 file (.ex)
Generated mime app
==> nimble_options
Compiling 3 files (.ex)
Generated nimble_options app
===> Analyzing applications...
===> Compiling telemetry
==> jason
Compiling 10 files (.ex)
Generated jason app
==> hpax
Compiling 4 files (.ex)
Generated hpax app
==> mint
Compiling 1 file (.erl)
Compiling 20 files (.ex)
Generated mint app
==> complex
Compiling 2 files (.ex)
Generated complex app
==> nx
Compiling 36 files (.ex)
Generated nx app
==> nimble_pool
Compiling 2 files (.ex)
Generated nimble_pool app
==> finch
Compiling 14 files (.ex)
Generated finch app
==> req
Compiling 17 files (.ex)
Generated req app
==> ex_task
Compiling 2 files (.ex)
Generated ex_task app
go-task/task info checking GitHub for latest tag
go-task/task debug http_download https://github.com/go-task/task/releases/latest
go-task/task info found version: 3.40.1 for v3.40.1/darwin/arm64
go-task/task debug downloading files into /var/folders/s4/8mb121gd2y94rh09nqtlp4hw0000gn/T/tmp.QOYZrOHSUQ
go-task/task debug http_download https://github.com/go-task/task/releases/download/v3.40.1/task_darwin_arm64.tar.gz
go-task/task debug http_download https://github.com/go-task/task/releases/download/v3.40.1/task_checksums.txt
go-task/task info installed /Users/zacky/github/nx_sgemm/_build/test/lib/ex_task/bin/task

==> nx_sgemm
Compiling 1 file (.ex)
Generated nx_sgemm app
Running ExUnit with seed: 34060, max_cases: 32

......
Finished in 0.00 seconds (0.00s async, 0.00s sync)
6 doctests, 0 failures
4
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?