LoginSignup
12
1

More than 1 year has passed since last update.

Elixirで順に指定個数の素数を列挙する関数をNxでも作ってみた

Last updated at Posted at 2022-12-19

ふと興味を持ったので,エラトステネスのふるいを純朴に利用して,Elixirで小さいものから順に指定個数の素数を列挙する関数をEnum, Stream, Flowを使って作ってみたのですが,さらにNxでも作ってみました.

シリーズ

ソースコード全体

Mix.install(
  [
    {:nx, "~> 0.4.1"},
    {:exla, "~> 0.4.1"},
    {:flow, "~> 1.2"},
    {:benchee, "~> 1.1"}
  ],
  config: [
    nx: [default_backend: EXLA.Backend]
  ]
)

defmodule Prime do
  import Nx.Defn

  def prime_candidates() do
    Stream.unfold(2, fn
      2 -> {2, 3}
      n -> {n, n + 2}
    end)
  end

  def prime_from_index(index) do
    Enum.at(prime_candidates(), index)
  end

  def prime_nx(count) do
    pc =
      prime_candidates()
      |> Enum.take(count)
      |> Nx.tensor()

    pc
    |> prime_nx_sub(shape: {count, count})
    |> Nx.to_flat_list()
    |> Enum.reject(& &1 == 0)
  end

  defnp prime_nx_sub(pc, opts \\ []) do
    opts = keyword!(opts, shape: {1, 1})

    pc
    |> Nx.broadcast(opts[:shape])
    |> Nx.transpose()
    |> Nx.remainder(pc)
    |> Nx.add(Nx.eye(opts[:shape]))
    |> Nx.reduce_min([axes: [1]])
    |> Nx.multiply(pc)
  end

  def prime_enum(count) do
    prime_candidates()
    |> Stream.map(fn pr -> {pr, Stream.take_while(prime_candidates(), & &1 < pr)} end)
    |> Enum.take(count)
    |> Enum.map(fn {pr, prs} -> {pr, Enum.filter(prs, & rem(pr, &1) == 0)} end)
    |> Enum.filter(fn {_pr, divisors} -> Enum.count(divisors) == 0 end)
    |> Enum.map(fn {pr, _} -> pr end)
  end

  def prime_stream(count) do
    prime_candidates()
    |> Stream.map(fn pr -> {pr, Stream.take_while(prime_candidates(), & &1 < pr)} end)
    |> Stream.take(count)
    |> Stream.map(fn {pr, prs} -> {pr, Stream.filter(prs, & rem(pr, &1) == 0)} end)
    |> Stream.filter(fn {_pr, divisors} -> Enum.count(divisors) == 0 end)
    |> Stream.map(fn {pr, _} -> pr end)
    |> Enum.to_list()
  end

  def prime_flow(count) do
    prime_candidates()
    |> Stream.map(fn pr -> {pr, Stream.take_while(prime_candidates(), & &1 < pr)} end)
    |> Stream.take(count)
    |> Flow.from_enumerable(max_demand: 1)
    |> Flow.map(fn {pr, prs} -> {pr, Stream.filter(prs, & rem(pr, &1) == 0)} end)
    |> Flow.filter(fn {_pr, divisors} -> Enum.count(divisors) == 0 end)
    |> Flow.map(fn {pr, _} -> pr end)
    |> Enum.to_list()
  end
end

Benchee.run(
  %{
    "prime_enum" => fn count -> Prime.prime_enum(count) end,
    "prime_stream" => fn count -> Prime.prime_stream(count) end,
    "prime_flow" => fn count -> Prime.prime_flow(count) end,
    "prime_exla" => fn count -> Prime.prime_nx(count) end
  },
  inputs: %{
    "10" => 10,
    "100" => 100,
    "1000" => 1000,
    "10000" => 10000
  }
)

コード解説

  def prime_from_index(index) do
    Enum.at(prime_candidates(), index)
  end

素数候補(prime_candidates)のindex番目の要素を取り出します.

  def prime_nx(count) do
    pc =
      prime_candidates()
      |> Enum.take(count)
      |> Nx.tensor()

素数候補(prime_candidates)をcount番目の要素まで取り出して,ベクトルにします.これを変数pcとします.

  def prime_nx(count) do
    pc =
      prime_candidates()
      |> Enum.take(count)
      |> Nx.tensor()

    pc
    |> prime_nx_sub(shape: {count, count})
  defnp prime_nx_sub(pc, count) do
    opts = keyword!(opts, shape: {1, 1})

    pc
    |> Nx.broadcast(opts[:shape])

素数候補(prime_candidates)をcount番目の要素まで取り出したベクトルを,count行並べた行列を作ります.またshapeを渡すには,Nx.Defn.Kernel.keyword!/2 https://hexdocs.pm/nx/Nx.Defn.Kernel.html#keyword!/2 を使います.例えばcount == 10の時には次のようになります.

#Nx.Tensor<
  s64[10][10]
  EXLA.Backend<host:0, 0.3485761066.2791702544.238252>
  [
    [2, 3, 5, 7, 9, 11, 13, 15, 17, 19],
    [2, 3, 5, 7, 9, 11, 13, 15, 17, 19],
    [2, 3, 5, 7, 9, 11, 13, 15, 17, 19],
    [2, 3, 5, 7, 9, 11, 13, 15, 17, 19],
    [2, 3, 5, 7, 9, 11, 13, 15, 17, 19],
    ...
  ]
>

次に転置します.

  def prime_nx(count) do
    pc =
      prime_candidates()
      |> Enum.take(count)
      |> Nx.tensor()

    pc
    |> prime_nx_sub(shape: {count, count})
  defnp prime_nx_sub(pc, count) do
    opts = keyword!(opts, shape: {1, 1})

    pc
    |> Nx.broadcast(opts[:shape])
    |> Nx.transpose()

count == 10の時には次のようになります.

#Nx.Tensor<
  s64[10][10]
  EXLA.Backend<host:0, 0.3485761066.2791702544.238254>
  [
    [2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    [3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
    [5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
    [7, 7, 7, 7, 7, 7, 7, 7, 7, 7],
    [9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
    ...
  ]
>

これとベクトルpcを照らし合わせて余りを取ります.

  def prime_nx(count) do
    pc =
      prime_candidates()
      |> Enum.take(count)
      |> Nx.tensor()

    pc
    |> prime_nx_sub(shape: {count, count})
  defnp prime_nx_sub(pc, count) do
    opts = keyword!(opts, shape: {1, 1})

    pc
    |> Nx.broadcast(opts[:shape])
    |> Nx.transpose()
    |> Nx.remainder(pc)

次のようになります.

#Nx.Tensor<
  s64[10][10]
  EXLA.Backend<host:0, 0.3485761066.2791702544.238256>
  [
    [0, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    [1, 0, 3, 3, 3, 3, 3, 3, 3, 3],
    [1, 2, 0, 5, 5, 5, 5, 5, 5, 5],
    [1, 1, 2, 0, 7, 7, 7, 7, 7, 7],
    [1, 0, 4, 2, 0, 9, 9, 9, 9, 9],
    ...
  ]
>

行と列が等しい場合を除いた時に余りが0になるような行に対応する素数候補は,0と自分自身以外の約数を持つので素数ではないということになります.行と列が等しい場合の01にするために,単位行列を足します.これにより,素数の場合は行の最小値が1に,素数でない場合は行の最小値が0になります.

  def prime_nx(count) do
    pc =
      prime_candidates()
      |> Enum.take(count)
      |> Nx.tensor()

    pc
    |> prime_nx_sub(shape: {count, count})
  defnp prime_nx_sub(pc, count) do
    opts = keyword!(opts, shape: {1, 1})

    pc
    |> Nx.broadcast(opts[:shape])
    |> Nx.transpose()
    |> Nx.remainder(pc)
    |> Nx.add(Nx.eye(opts[:shape]))

すると次のようになります.

#Nx.Tensor<
  s64[10][10]
  EXLA.Backend<host:0, 0.3485761066.2791702544.238260>
  [
    [1, 2, 2, 2, 2, 2, 2, 2, 2, 2],
    [1, 1, 3, 3, 3, 3, 3, 3, 3, 3],
    [1, 2, 1, 5, 5, 5, 5, 5, 5, 5],
    [1, 1, 2, 1, 7, 7, 7, 7, 7, 7],
    [1, 0, 4, 2, 1, 9, 9, 9, 9, 9],
    ...
  ]
>

次に行単位で集約して,最小値を計算します.前述のように最小値は1もしくは0になるはずで,最小値が0になるものは素数ではありません.行単位で集約するために,axes: [1]を付記します.この計算が高速化の肝になります.

  def prime_nx(count) do
    pc =
      prime_candidates()
      |> Enum.take(count)
      |> Nx.tensor()

    pc
    |> prime_nx_sub(shape: {count, count})
  defnp prime_nx_sub(pc, count) do
    opts = keyword!(opts, shape: {1, 1})

    pc
    |> Nx.broadcast(opts[:shape])
    |> Nx.transpose()
    |> Nx.remainder(pc)
    |> Nx.add(Nx.eye(opts[:shape]))
    |> Nx.reduce_min([axes: [1]])

すると次のようになります.

#Nx.Tensor<
  s64[10]
  EXLA.Backend<host:0, 0.3485761066.2791702544.238262>
  [1, 1, 1, 1, 0, 1, 1, 0, 1, 1]
>

これに素数候補ベクトルpcを乗じます.

  def prime_nx(count) do
    pc =
      prime_candidates()
      |> Enum.take(count)
      |> Nx.tensor()

    pc
    |> prime_nx_sub(shape: {count, count})
  defnp prime_nx_sub(pc, count) do
    opts = keyword!(opts, shape: {1, 1})

    pc
    |> Nx.broadcast(opts[:shape])
    |> Nx.transpose()
    |> Nx.remainder(pc)
    |> Nx.add(Nx.eye(opts[:shape]))
    |> Nx.reduce_min([axes: [1]])
    |> Nx.multiply(pc)
  end

すると次のようになります.ほぼ完成形ですね.

 #Nx.Tensor<
  s64[10]
  EXLA.Backend<host:0, 0.3485761066.2791702544.238264>
  [2, 3, 5, 7, 0, 11, 13, 0, 17, 19]
>

あとはリストにして,0の要素を取り除きます.

  def prime_nx(count) do
    pc =
      prime_candidates()
      |> Enum.take(count)
      |> Nx.tensor()

    pc
    |> prime_nx_sub(shape: {count, count})
    |> Nx.to_flat_list()
    |> Enum.reject(& &1 == 0)
  end

  defnp prime_nx_sub(pc, count) do
    opts = keyword!(opts, shape: {1, 1})

    pc
    |> Nx.broadcast(opts[:shape])
    |> Nx.transpose()
    |> Nx.remainder(pc)
    |> Nx.add(Nx.eye(opts[:shape]))
    |> Nx.reduce_min([axes: [1]])
    |> Nx.multiply(pc)
  end

ベンチマーク結果(MacStudio on M1 Ultra)

% elixir prime_nx.exs
Operating System: macOS
CPU Information: Apple M1 Ultra
Number of Available Cores: 20
Available memory: 128 GB
Elixir 1.14.2
Erlang 25.2

Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: 10, 100, 1000, 10000
Estimated total run time: 1.87 min

Benchmarking prime_enum with input 10 ...
Benchmarking prime_enum with input 100 ...
Benchmarking prime_enum with input 1000 ...
Benchmarking prime_enum with input 10000 ...
Benchmarking prime_exla with input 10 ...

04:48:17.360 [info] TfrtCpuClient created.
Benchmarking prime_exla with input 100 ...
Benchmarking prime_exla with input 1000 ...
Benchmarking prime_exla with input 10000 ...
Benchmarking prime_flow with input 10 ...
Benchmarking prime_flow with input 100 ...
Benchmarking prime_flow with input 1000 ...
Benchmarking prime_flow with input 10000 ...
Benchmarking prime_stream with input 10 ...
Benchmarking prime_stream with input 100 ...
Benchmarking prime_stream with input 1000 ...
Benchmarking prime_stream with input 10000 ...

##### With input 10 #####
Name                   ips        average  deviation         median         99th %
prime_enum        267.22 K        3.74 μs   ±305.23%        3.25 μs        8.38 μs
prime_stream      264.72 K        3.78 μs   ±273.67%        3.29 μs          10 μs
prime_flow          2.71 K      368.63 μs    ±11.17%      360.75 μs      507.38 μs
prime_exla          1.58 K      631.20 μs    ±13.61%      627.04 μs      873.26 μs

Comparison: 
prime_enum        267.22 K
prime_stream      264.72 K - 1.01x slower +0.0352 μs
prime_flow          2.71 K - 98.50x slower +364.88 μs
prime_exla          1.58 K - 168.67x slower +627.45 μs

##### With input 100 #####
Name                   ips        average  deviation         median         99th %
prime_stream        8.92 K      112.10 μs    ±20.23%      104.83 μs      188.04 μs
prime_enum          6.29 K      159.06 μs    ±16.49%      158.13 μs      277.16 μs
prime_exla          1.49 K      669.69 μs    ±12.57%      668.77 μs      889.64 μs
prime_flow          1.39 K      718.27 μs    ±13.27%      705.04 μs      979.99 μs

Comparison: 
prime_stream        8.92 K
prime_enum          6.29 K - 1.42x slower +46.97 μs
prime_exla          1.49 K - 5.97x slower +557.59 μs
prime_flow          1.39 K - 6.41x slower +606.17 μs

##### With input 1000 #####
Name                   ips        average  deviation         median         99th %
prime_exla          273.52        3.66 ms     ±3.71%        3.64 ms        4.05 ms
prime_flow          226.12        4.42 ms     ±5.25%        4.41 ms        5.16 ms
prime_stream        123.90        8.07 ms     ±1.32%        8.06 ms        8.31 ms
prime_enum           76.80       13.02 ms     ±5.59%       13.23 ms       13.68 ms

Comparison: 
prime_exla          273.52
prime_flow          226.12 - 1.21x slower +0.77 ms
prime_stream        123.90 - 2.21x slower +4.41 ms
prime_enum           76.80 - 3.56x slower +9.36 ms

##### With input 10000 #####
Name                   ips        average  deviation         median         99th %
prime_flow           11.26       88.79 ms     ±0.91%       88.67 ms       90.97 ms
prime_exla            3.51      284.81 ms     ±0.42%      284.53 ms      288.39 ms
prime_stream          1.29      775.55 ms     ±0.20%      775.77 ms      778.11 ms
prime_enum            0.86     1158.24 ms     ±0.77%     1155.26 ms     1173.25 ms

Comparison: 
prime_flow           11.26
prime_exla            3.51 - 3.21x slower +196.02 ms
prime_stream          1.29 - 8.73x slower +686.76 ms
prime_enum            0.86 - 13.04x slower +1069.45 ms
  • 10個の時にはEnumが最も速いです.
  • 100個の時にはStreamが最も速いです.
  • 1000個の時にはNx(EXLA)が最も速いです.
  • 10000個の時にはFlowが最も速いです.
12
1
9

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
1