2024年になりました.新年最初のQiita記事として,2023年末にかなり苦労して実装した関数「Streamによる無限長のビット列の右シフト演算」をどのように実装したかを紹介します.
実装例
Elixirがインストールされているとします.まず,次のコマンドを打ち込みます.
mix new bit_shifter
cd bit_shifter
mkdir lib/bit_shifter
mkdir test/bit_shifter
前準備: BitShifter.ListWrapper
まずリストをStream化するに沿って ListWrapper
モジュールを作りましょう.
テストコードからです.
defmodule BitShifter.ListWrapperTest do
use ExUnit.Case
doctest BitShifter.ListWrapper
test "get([])" do
BitShifter.ListWrapper.get([])
|> Stream.map(&assert &1 == 0)
|> Enum.take(10)
end
test "get([1, 2])" do
BitShifter.ListWrapper.get([1, 2])
|> Stream.with_index()
|> Stream.map(fn
{v, 0} -> assert v == 1
{v, 1} -> assert v == 2
{v, _} -> assert v == 0
end)
|> Enum.take(10)
end
end
当たり前ですが,テストは通りません.
bit_shifter% mix test
Compiling 1 file (.ex)
Generated bit_shifter app
error: module BitShifter.ListWrapper is not loaded and could not be found
│
3 │ doctest BitShifter.ListWrapper
│ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
│
└─ test/bit_shifter/list_wrapper_test.exs:3: BitShifter.ListWrapperTest (module)
== Compilation error in file test/bit_shifter/list_wrapper_test.exs ==
** (CompileError) test/bit_shifter/list_wrapper_test.exs: cannot compile module BitShifter.ListWrapperTest (errors have been logged)
(ex_unit 1.16.0) expanding macro: ExUnit.DocTest.doctest/1
test/bit_shifter/list_wrapper_test.exs:3: BitShifter.ListWrapperTest (module)
bit_shifter%
次にモジュールBitShifter.ListWrapper
を実装します.
defmodule BitShifter.ListWrapper do
@moduledoc """
A module to wrap a list with `Stream`.
"""
@spec get(list(non_neg_integer()) | Enumerable.t(), non_neg_integer()) :: Enumerable.t()
@doc """
Returns the given `list` wrapped in a `Stream`
such that the value following the list is `default_value`.
"""
def get(list, default_value \\ 0)
def get(list, default_value) when is_list(list) do
Stream.unfold(list, fn
[] -> {default_value, []}
[h | t] -> {h, t}
end)
end
def get(range, default_value) when is_struct(range, Range) do
Stream.unfold(Enum.to_list(range), fn
[] -> {default_value, []}
[h | t] -> {h, t}
end)
end
def get(stream, _default_value), do: stream
end
リストをStream化するでも例示したように,このコードそのものは,任意の値のリストlist
,任意のデフォルト値default_value
を受け付けるのですが,ビット列を表すようにしたいので,Typespecsをキツめに,それぞれ,非負の整数値non_neg_integer()
が入るようにしています.また,関数パターンマッチを使っています.本当は,関数パターンマッチで,Stream
の場合だけ,素のまま返すとしたいのですが,うまくマッチする方法を見出せませんでした.
mix test
すると,テストが通ります.
bit_shifter % mix test
....
Finished in 0.01 seconds (0.00s async, 0.01s sync)
1 doctest, 3 tests, 0 failures
本題: BitShifter
ここでようやく本題です.
shift = 0
の場合
テストコードを作ります.
defmodule BitShifterTest do
use ExUnit.Case
doctest BitShifter
describe "get" do
test "shift = 0" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(0)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0xAA
{n, 1} -> assert n == 0x55
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
end
end
mix test
すると,エラーになることを確認します.まだ実装していないので,これでOKです.
bit_shifter % mix test
warning: BitShifter.get/2 is undefined or private
│
9 │ |> BitShifter.get(0)
│ ~
│
└─ test/bit_shifter_test.exs:9:21: BitShifterTest."test get shift = 0"/1
.
1) test get shift = 0 (BitShifterTest)
test/bit_shifter_test.exs:6
** (UndefinedFunctionError) function BitShifter.get/2 is undefined or private
code: |> BitShifter.get(0)
stacktrace:
(bit_shifter 0.1.0) BitShifter.get(#Function<63.53678557/2 in Stream.unfold/2>, 0)
test/bit_shifter_test.exs:9: (test)
..
Finished in 0.01 seconds (0.00s async, 0.01s sync)
1 doctest, 3 tests, 1 failure
Randomized with seed 590452
bit_shifter %
次に,lib/bit_shifter.ex
を次のように実装します.
defmodule BitShifter do
@moduledoc """
A module to right shift an infinite length byte `Stream`.
"""
@spec get(Enumerable.t(), non_neg_integer()) :: Enumerable.t()
@doc """
Returns the result of shifting the given infinite length byte `stream`
by the given `shift` bits to the right.
"""
def get(stream, shift)
def get(stream, 0), do: stream
end
そうすると,mix test
が通ります.
% mix test
...
Finished in 0.01 seconds (0.00s async, 0.01s sync)
3 tests, 0 failures
shift = 8
, shift = 16
の場合
次に8ビットシフトと16ビットシフトを足します.
defmodule BitShifterTest do
use ExUnit.Case
doctest BitShifter
describe "get" do
test "shift = 0" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(0)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0xAA
{n, 1} -> assert n == 0x55
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
test "shift = 8" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(8)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0
{n, 1} -> assert n == 0xAA
{n, 2} -> assert n == 0x55
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
test "shift = 16" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(16)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0
{n, 1} -> assert n == 0
{n, 2} -> assert n == 0xAA
{n, 3} -> assert n == 0x55
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
end
end
この時,当然ながらテストは失敗します.
次に実装します.
defmodule BitShifter do
@moduledoc """
A module to right shift an infinite length byte `Stream`.
"""
@spec get(Enumerable.t(), non_neg_integer()) :: Enumerable.t()
@doc """
Returns the result of shifting the given infinite length byte `stream`
by the given `shift` bits to the right.
"""
def get(stream, shift)
def get(stream, 0), do: stream
def get(stream, shift) when is_integer(shift) and 8 <= shift do
Stream.unfold(
{Bitwise.bsr(shift, 3), get(stream, Bitwise.band(shift, 7))},
fn
{0, s} -> {Enum.at(s, 0), {0, Stream.drop(s, 1)}}
{n, s} -> {0, {n - 1, s}}
end
)
end
end
これで,テストが通ることを確認します.
シフト値からマスクを生成する関数BitShifter.MaskFromShift.get/1
を作成する
いよいよshift
が任意の場合にチャレンジするわけですが,その前に次のような関数BitShifter.MaskFromShift.get/1
を作ります.
テストコードで仕様を見ていきます.
defmodule BitShifter.MaskFromShiftTest do
use ExUnit.Case
doctest BitShifter.MaskFromShift
describe "get" do
test "0" do
assert BitShifter.MaskFromShift.get(0) == 0
end
test "1" do
assert BitShifter.MaskFromShift.get(1) == 1
end
test "2" do
assert BitShifter.MaskFromShift.get(2) == 3
end
test "3" do
assert BitShifter.MaskFromShift.get(3) == 7
end
test "4" do
assert BitShifter.MaskFromShift.get(4) == 0xF
end
test "5" do
assert BitShifter.MaskFromShift.get(5) == 0x1F
end
test "6" do
assert BitShifter.MaskFromShift.get(6) == 0x3F
end
test "7" do
assert BitShifter.MaskFromShift.get(7) == 0x7F
end
end
end
実装は次のようになります.
defmodule BitShifter.MaskFromShift do
@moduledoc """
A module that generates mask bits corresponding to the shift value.
"""
@spec get(non_neg_integer()) :: non_neg_integer()
@doc """
Returns mask bits corresponding to the given `shift` value.
"""
def get(shift)
def get(0), do: 0
def get(shift) when is_number(shift) and 0 < shift and shift < 8 do
shift = shift - 1
Enum.reduce(0..shift, 0, fn _, acc ->
acc
|> Bitwise.bsl(1)
|> Bitwise.bor(1)
end)
end
end
shift
を8で割った時の余りが0以外の場合
次がいよいよ山場です.
テストを実装します.
defmodule BitShifterTest do
use ExUnit.Case
doctest BitShifter
describe "get" do
test "shift = 0" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(0)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0xAA
{n, 1} -> assert n == 0x55
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
test "shift = 8" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(8)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0
{n, 1} -> assert n == 0xAA
{n, 2} -> assert n == 0x55
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
test "shift = 16" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(16)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0
{n, 1} -> assert n == 0
{n, 2} -> assert n == 0xAA
{n, 3} -> assert n == 0x55
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
test "shift = 1" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(1)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0x55
{n, 1} -> assert n == 0x2A
{n, 2} -> assert n == 0x80
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
test "shift = 2" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(2)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0x2A
{n, 1} -> assert n == 0x95
{n, 2} -> assert n == 0x40
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
test "shift = 3" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(3)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0x15
{n, 1} -> assert n == 0x4A
{n, 2} -> assert n == 0xA0
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
test "shift = 4" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(4)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0x0A
{n, 1} -> assert n == 0xA5
{n, 2} -> assert n == 0x50
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
test "shift = 5" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(5)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0x05
{n, 1} -> assert n == 0x52
{n, 2} -> assert n == 0xA8
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
test "shift = 6" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(6)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0x02
{n, 1} -> assert n == 0xA9
{n, 2} -> assert n == 0x54
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
test "shift = 7" do
[0xAA, 0x55]
|> BitShifter.ListWrapper.get(0)
|> BitShifter.get(7)
|> Stream.with_index()
|> Stream.map(fn
{n, 0} -> assert n == 0x01
{n, 1} -> assert n == 0x54
{n, 2} -> assert n == 0xAA
{n, _} -> assert n == 0
end)
|> Enum.take(10)
end
end
end
テストが通らないことを確認します.
次に実装します.
defmodule BitShifter do
@moduledoc """
A module to right shift an infinite length byte `Stream`.
"""
@spec get(Enumerable.t(), non_neg_integer()) :: Enumerable.t()
@doc """
Returns the result of shifting the given infinite length byte `stream`
by the given `shift` bits to the right.
"""
def get(stream, shift)
def get(stream, 0), do: stream
def get(stream, shift) when is_integer(shift) and 0 < shift and shift < 8 do
stream
|> Stream.scan([], &[&1 | &2])
|> Stream.drop(1)
|> Stream.map(fn [l | [h | _]] -> {h, l} end)
|> Stream.map(fn {h, l} -> h |> Bitwise.bsl(8) |> Bitwise.bor(l) end)
|> Stream.map(fn v ->
{
Bitwise.bsr(v, shift),
Bitwise.band(v, BitShifter.MaskFromShift.get(shift))
}
end)
|> Stream.map(fn {h, l} -> {Bitwise.bsr(h, 8), Bitwise.band(h, 0xFF), l} end)
|> Stream.scan([], &[&1 | &2])
|> Stream.with_index()
|> Stream.map(fn
{[l | _], 0} -> {{0, 0, 0}, l}
{[l | [h | _]], _} -> {h, l}
end)
|> Stream.map(fn {{_h1, m1, _l1}, {h2, _m2, _l2}} ->
Bitwise.bor(m1, h2)
end)
end
def get(stream, shift) when is_integer(shift) and 8 <= shift do
Stream.unfold(
{Bitwise.bsr(shift, 3), get(stream, Bitwise.band(shift, 7))},
fn
{0, s} -> {Enum.at(s, 0), {0, Stream.drop(s, 1)}}
{n, s} -> {0, {n - 1, s}}
end
)
end
end
テストが通りました!