ふと興味を持ったので,エラトステネスのふるいを純朴に利用して,Elixirで小さいものから順に指定個数の素数を列挙する関数をEnum, Stream, Flowを使って作ってみたのですが,さらにNxでも作ってみました.
シリーズ
ソースコード全体
Mix.install(
[
{:nx, "~> 0.4.1"},
{:exla, "~> 0.4.1"},
{:flow, "~> 1.2"},
{:benchee, "~> 1.1"}
],
config: [
nx: [default_backend: EXLA.Backend]
]
)
defmodule Prime do
import Nx.Defn
def prime_candidates() do
Stream.unfold(2, fn
2 -> {2, 3}
n -> {n, n + 2}
end)
end
def prime_from_index(index) do
Enum.at(prime_candidates(), index)
end
def prime_nx(count) do
pc =
prime_candidates()
|> Enum.take(count)
|> Nx.tensor()
pc
|> prime_nx_sub(shape: {count, count})
|> Nx.to_flat_list()
|> Enum.reject(& &1 == 0)
end
defnp prime_nx_sub(pc, opts \\ []) do
opts = keyword!(opts, shape: {1, 1})
pc
|> Nx.broadcast(opts[:shape])
|> Nx.transpose()
|> Nx.remainder(pc)
|> Nx.add(Nx.eye(opts[:shape]))
|> Nx.reduce_min([axes: [1]])
|> Nx.multiply(pc)
end
def prime_enum(count) do
prime_candidates()
|> Stream.map(fn pr -> {pr, Stream.take_while(prime_candidates(), & &1 < pr)} end)
|> Enum.take(count)
|> Enum.map(fn {pr, prs} -> {pr, Enum.filter(prs, & rem(pr, &1) == 0)} end)
|> Enum.filter(fn {_pr, divisors} -> Enum.count(divisors) == 0 end)
|> Enum.map(fn {pr, _} -> pr end)
end
def prime_stream(count) do
prime_candidates()
|> Stream.map(fn pr -> {pr, Stream.take_while(prime_candidates(), & &1 < pr)} end)
|> Stream.take(count)
|> Stream.map(fn {pr, prs} -> {pr, Stream.filter(prs, & rem(pr, &1) == 0)} end)
|> Stream.filter(fn {_pr, divisors} -> Enum.count(divisors) == 0 end)
|> Stream.map(fn {pr, _} -> pr end)
|> Enum.to_list()
end
def prime_flow(count) do
prime_candidates()
|> Stream.map(fn pr -> {pr, Stream.take_while(prime_candidates(), & &1 < pr)} end)
|> Stream.take(count)
|> Flow.from_enumerable(max_demand: 1)
|> Flow.map(fn {pr, prs} -> {pr, Stream.filter(prs, & rem(pr, &1) == 0)} end)
|> Flow.filter(fn {_pr, divisors} -> Enum.count(divisors) == 0 end)
|> Flow.map(fn {pr, _} -> pr end)
|> Enum.to_list()
end
end
Benchee.run(
%{
"prime_enum" => fn count -> Prime.prime_enum(count) end,
"prime_stream" => fn count -> Prime.prime_stream(count) end,
"prime_flow" => fn count -> Prime.prime_flow(count) end,
"prime_exla" => fn count -> Prime.prime_nx(count) end
},
inputs: %{
"10" => 10,
"100" => 100,
"1000" => 1000,
"10000" => 10000
}
)
コード解説
def prime_from_index(index) do
Enum.at(prime_candidates(), index)
end
素数候補(prime_candidates)のindex
番目の要素を取り出します.
def prime_nx(count) do
pc =
prime_candidates()
|> Enum.take(count)
|> Nx.tensor()
素数候補(prime_candidates)をcount
番目の要素まで取り出して,ベクトルにします.これを変数pc
とします.
def prime_nx(count) do
pc =
prime_candidates()
|> Enum.take(count)
|> Nx.tensor()
pc
|> prime_nx_sub(shape: {count, count})
defnp prime_nx_sub(pc, count) do
opts = keyword!(opts, shape: {1, 1})
pc
|> Nx.broadcast(opts[:shape])
素数候補(prime_candidates)をcount
番目の要素まで取り出したベクトルを,count
行並べた行列を作ります.またshape
を渡すには,Nx.Defn.Kernel.keyword!/2
https://hexdocs.pm/nx/Nx.Defn.Kernel.html#keyword!/2 を使います.例えばcount == 10
の時には次のようになります.
#Nx.Tensor<
s64[10][10]
EXLA.Backend<host:0, 0.3485761066.2791702544.238252>
[
[2, 3, 5, 7, 9, 11, 13, 15, 17, 19],
[2, 3, 5, 7, 9, 11, 13, 15, 17, 19],
[2, 3, 5, 7, 9, 11, 13, 15, 17, 19],
[2, 3, 5, 7, 9, 11, 13, 15, 17, 19],
[2, 3, 5, 7, 9, 11, 13, 15, 17, 19],
...
]
>
次に転置します.
def prime_nx(count) do
pc =
prime_candidates()
|> Enum.take(count)
|> Nx.tensor()
pc
|> prime_nx_sub(shape: {count, count})
defnp prime_nx_sub(pc, count) do
opts = keyword!(opts, shape: {1, 1})
pc
|> Nx.broadcast(opts[:shape])
|> Nx.transpose()
count == 10
の時には次のようになります.
#Nx.Tensor<
s64[10][10]
EXLA.Backend<host:0, 0.3485761066.2791702544.238254>
[
[2, 2, 2, 2, 2, 2, 2, 2, 2, 2],
[3, 3, 3, 3, 3, 3, 3, 3, 3, 3],
[5, 5, 5, 5, 5, 5, 5, 5, 5, 5],
[7, 7, 7, 7, 7, 7, 7, 7, 7, 7],
[9, 9, 9, 9, 9, 9, 9, 9, 9, 9],
...
]
>
これとベクトルpc
を照らし合わせて余りを取ります.
def prime_nx(count) do
pc =
prime_candidates()
|> Enum.take(count)
|> Nx.tensor()
pc
|> prime_nx_sub(shape: {count, count})
defnp prime_nx_sub(pc, count) do
opts = keyword!(opts, shape: {1, 1})
pc
|> Nx.broadcast(opts[:shape])
|> Nx.transpose()
|> Nx.remainder(pc)
次のようになります.
#Nx.Tensor<
s64[10][10]
EXLA.Backend<host:0, 0.3485761066.2791702544.238256>
[
[0, 2, 2, 2, 2, 2, 2, 2, 2, 2],
[1, 0, 3, 3, 3, 3, 3, 3, 3, 3],
[1, 2, 0, 5, 5, 5, 5, 5, 5, 5],
[1, 1, 2, 0, 7, 7, 7, 7, 7, 7],
[1, 0, 4, 2, 0, 9, 9, 9, 9, 9],
...
]
>
行と列が等しい場合を除いた時に余りが0
になるような行に対応する素数候補は,0
と自分自身以外の約数を持つので素数ではないということになります.行と列が等しい場合の0
を1
にするために,単位行列を足します.これにより,素数の場合は行の最小値が1
に,素数でない場合は行の最小値が0
になります.
def prime_nx(count) do
pc =
prime_candidates()
|> Enum.take(count)
|> Nx.tensor()
pc
|> prime_nx_sub(shape: {count, count})
defnp prime_nx_sub(pc, count) do
opts = keyword!(opts, shape: {1, 1})
pc
|> Nx.broadcast(opts[:shape])
|> Nx.transpose()
|> Nx.remainder(pc)
|> Nx.add(Nx.eye(opts[:shape]))
すると次のようになります.
#Nx.Tensor<
s64[10][10]
EXLA.Backend<host:0, 0.3485761066.2791702544.238260>
[
[1, 2, 2, 2, 2, 2, 2, 2, 2, 2],
[1, 1, 3, 3, 3, 3, 3, 3, 3, 3],
[1, 2, 1, 5, 5, 5, 5, 5, 5, 5],
[1, 1, 2, 1, 7, 7, 7, 7, 7, 7],
[1, 0, 4, 2, 1, 9, 9, 9, 9, 9],
...
]
>
次に行単位で集約して,最小値を計算します.前述のように最小値は1
もしくは0
になるはずで,最小値が0
になるものは素数ではありません.行単位で集約するために,axes: [1]
を付記します.この計算が高速化の肝になります.
def prime_nx(count) do
pc =
prime_candidates()
|> Enum.take(count)
|> Nx.tensor()
pc
|> prime_nx_sub(shape: {count, count})
defnp prime_nx_sub(pc, count) do
opts = keyword!(opts, shape: {1, 1})
pc
|> Nx.broadcast(opts[:shape])
|> Nx.transpose()
|> Nx.remainder(pc)
|> Nx.add(Nx.eye(opts[:shape]))
|> Nx.reduce_min([axes: [1]])
すると次のようになります.
#Nx.Tensor<
s64[10]
EXLA.Backend<host:0, 0.3485761066.2791702544.238262>
[1, 1, 1, 1, 0, 1, 1, 0, 1, 1]
>
これに素数候補ベクトルpc
を乗じます.
def prime_nx(count) do
pc =
prime_candidates()
|> Enum.take(count)
|> Nx.tensor()
pc
|> prime_nx_sub(shape: {count, count})
defnp prime_nx_sub(pc, count) do
opts = keyword!(opts, shape: {1, 1})
pc
|> Nx.broadcast(opts[:shape])
|> Nx.transpose()
|> Nx.remainder(pc)
|> Nx.add(Nx.eye(opts[:shape]))
|> Nx.reduce_min([axes: [1]])
|> Nx.multiply(pc)
end
すると次のようになります.ほぼ完成形ですね.
#Nx.Tensor<
s64[10]
EXLA.Backend<host:0, 0.3485761066.2791702544.238264>
[2, 3, 5, 7, 0, 11, 13, 0, 17, 19]
>
あとはリストにして,0
の要素を取り除きます.
def prime_nx(count) do
pc =
prime_candidates()
|> Enum.take(count)
|> Nx.tensor()
pc
|> prime_nx_sub(shape: {count, count})
|> Nx.to_flat_list()
|> Enum.reject(& &1 == 0)
end
defnp prime_nx_sub(pc, count) do
opts = keyword!(opts, shape: {1, 1})
pc
|> Nx.broadcast(opts[:shape])
|> Nx.transpose()
|> Nx.remainder(pc)
|> Nx.add(Nx.eye(opts[:shape]))
|> Nx.reduce_min([axes: [1]])
|> Nx.multiply(pc)
end
ベンチマーク結果(MacStudio on M1 Ultra)
% elixir prime_nx.exs
Operating System: macOS
CPU Information: Apple M1 Ultra
Number of Available Cores: 20
Available memory: 128 GB
Elixir 1.14.2
Erlang 25.2
Benchmark suite executing with the following configuration:
warmup: 2 s
time: 5 s
memory time: 0 ns
reduction time: 0 ns
parallel: 1
inputs: 10, 100, 1000, 10000
Estimated total run time: 1.87 min
Benchmarking prime_enum with input 10 ...
Benchmarking prime_enum with input 100 ...
Benchmarking prime_enum with input 1000 ...
Benchmarking prime_enum with input 10000 ...
Benchmarking prime_exla with input 10 ...
04:48:17.360 [info] TfrtCpuClient created.
Benchmarking prime_exla with input 100 ...
Benchmarking prime_exla with input 1000 ...
Benchmarking prime_exla with input 10000 ...
Benchmarking prime_flow with input 10 ...
Benchmarking prime_flow with input 100 ...
Benchmarking prime_flow with input 1000 ...
Benchmarking prime_flow with input 10000 ...
Benchmarking prime_stream with input 10 ...
Benchmarking prime_stream with input 100 ...
Benchmarking prime_stream with input 1000 ...
Benchmarking prime_stream with input 10000 ...
##### With input 10 #####
Name ips average deviation median 99th %
prime_enum 267.22 K 3.74 μs ±305.23% 3.25 μs 8.38 μs
prime_stream 264.72 K 3.78 μs ±273.67% 3.29 μs 10 μs
prime_flow 2.71 K 368.63 μs ±11.17% 360.75 μs 507.38 μs
prime_exla 1.58 K 631.20 μs ±13.61% 627.04 μs 873.26 μs
Comparison:
prime_enum 267.22 K
prime_stream 264.72 K - 1.01x slower +0.0352 μs
prime_flow 2.71 K - 98.50x slower +364.88 μs
prime_exla 1.58 K - 168.67x slower +627.45 μs
##### With input 100 #####
Name ips average deviation median 99th %
prime_stream 8.92 K 112.10 μs ±20.23% 104.83 μs 188.04 μs
prime_enum 6.29 K 159.06 μs ±16.49% 158.13 μs 277.16 μs
prime_exla 1.49 K 669.69 μs ±12.57% 668.77 μs 889.64 μs
prime_flow 1.39 K 718.27 μs ±13.27% 705.04 μs 979.99 μs
Comparison:
prime_stream 8.92 K
prime_enum 6.29 K - 1.42x slower +46.97 μs
prime_exla 1.49 K - 5.97x slower +557.59 μs
prime_flow 1.39 K - 6.41x slower +606.17 μs
##### With input 1000 #####
Name ips average deviation median 99th %
prime_exla 273.52 3.66 ms ±3.71% 3.64 ms 4.05 ms
prime_flow 226.12 4.42 ms ±5.25% 4.41 ms 5.16 ms
prime_stream 123.90 8.07 ms ±1.32% 8.06 ms 8.31 ms
prime_enum 76.80 13.02 ms ±5.59% 13.23 ms 13.68 ms
Comparison:
prime_exla 273.52
prime_flow 226.12 - 1.21x slower +0.77 ms
prime_stream 123.90 - 2.21x slower +4.41 ms
prime_enum 76.80 - 3.56x slower +9.36 ms
##### With input 10000 #####
Name ips average deviation median 99th %
prime_flow 11.26 88.79 ms ±0.91% 88.67 ms 90.97 ms
prime_exla 3.51 284.81 ms ±0.42% 284.53 ms 288.39 ms
prime_stream 1.29 775.55 ms ±0.20% 775.77 ms 778.11 ms
prime_enum 0.86 1158.24 ms ±0.77% 1155.26 ms 1173.25 ms
Comparison:
prime_flow 11.26
prime_exla 3.51 - 3.21x slower +196.02 ms
prime_stream 1.29 - 8.73x slower +686.76 ms
prime_enum 0.86 - 13.04x slower +1069.45 ms
- 10個の時にはEnumが最も速いです.
- 100個の時にはStreamが最も速いです.
- 1000個の時にはNx(EXLA)が最も速いです.
- 10000個の時にはFlowが最も速いです.