LoginSignup
5
1

2024年になりました.新年最初のQiita記事として,2023年末にかなり苦労して実装した関数「Streamによる無限長のビット列の右シフト演算」をどのように実装したかを紹介します.

実装例

Elixirがインストールされているとします.まず,次のコマンドを打ち込みます.

# Generate a new Mix project and move into it.
mix new bit_shifter
cd bit_shifter
# Create the directories that will hold the sub-modules and their tests.
mkdir lib/bit_shifter
mkdir test/bit_shifter

前準備: BitShifter.ListWrapper

まず,記事「リストをStream化する」の方法に沿って,ListWrapper モジュールを作りましょう.

テストコードからです.

test/bit_shifter/list_wrapper_test.exs
defmodule BitShifter.ListWrapperTest do
  use ExUnit.Case
  doctest BitShifter.ListWrapper

  # Number of leading elements each test materialises and checks.
  @prefix_length 10

  # The prefix is materialised first and compared as a whole. Asserting inside
  # a lazy `Stream.map/2` would pass vacuously if the stream produced fewer
  # elements than requested, because the assertion never runs for elements
  # that are not there.

  test "get([])" do
    prefix = Enum.take(BitShifter.ListWrapper.get([]), @prefix_length)

    assert prefix == List.duplicate(0, @prefix_length)
  end

  test "get([1, 2])" do
    prefix = Enum.take(BitShifter.ListWrapper.get([1, 2]), @prefix_length)

    assert prefix == [1, 2 | List.duplicate(0, @prefix_length - 2)]
  end
end

当たり前ですが,テストは通りません.

bit_shifter% mix test
Compiling 1 file (.ex)
Generated bit_shifter app
    error: module BitShifter.ListWrapper is not loaded and could not be found
    │
  3 │   doctest BitShifter.ListWrapper
    │   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    │
    └─ test/bit_shifter/list_wrapper_test.exs:3: BitShifter.ListWrapperTest (module)


== Compilation error in file test/bit_shifter/list_wrapper_test.exs ==
** (CompileError) test/bit_shifter/list_wrapper_test.exs: cannot compile module BitShifter.ListWrapperTest (errors have been logged)
    (ex_unit 1.16.0) expanding macro: ExUnit.DocTest.doctest/1
    test/bit_shifter/list_wrapper_test.exs:3: BitShifter.ListWrapperTest (module)
bit_shifter% 

次にモジュールBitShifter.ListWrapperを実装します.

lib/bit_shifter/list_wrapper.ex
defmodule BitShifter.ListWrapper do
  @moduledoc """
  A module to wrap a list with `Stream`.
  """

  @doc """
  Returns the given `list` wrapped in a never-ending `Stream`.

  The stream first yields the elements of `list` in order and then yields
  `default_value` forever.

  ## Parameters

    * `list` - a list or a `Range`; any other value is assumed to already be
      a stream and is returned unchanged.
    * `default_value` - the value emitted after `list` is exhausted
      (defaults to `0`).
  """
  @spec get(list(non_neg_integer()) | Enumerable.t(), non_neg_integer()) :: Enumerable.t()
  def get(list, default_value \\ 0)

  # Lists and ranges share one lazy implementation. The input is consumed
  # element by element, so a large range is never materialised up front.
  def get(list, default_value) when is_list(list) or is_struct(list, Range) do
    Stream.concat(list, Stream.repeatedly(fn -> default_value end))
  end

  # Anything else is passed through untouched.
  def get(stream, _default_value), do: stream
end

記事「リストをStream化する」でも例示したように,このコードそのものは,任意の値のリストlistと任意のデフォルト値default_valueを受け付けます.ただし,ここではビット列を表すようにしたいので,Typespecsをキツめにして,それぞれに非負の整数値non_neg_integer()が入るようにしています.また,関数パターンマッチを使っています.本当は,関数パターンマッチでStreamの場合だけを素のまま返すようにしたいのですが,うまくマッチする方法を見出せなかったため,リストとRange以外の値はすべてそのまま返しています.

mix testすると,テストが通ります.

bit_shifter % mix test
....
Finished in 0.01 seconds (0.00s async, 0.01s sync)
1 doctest, 3 tests, 0 failures

本題: BitShifter

ここでようやく本題です.

shift = 0の場合

テストコードを作ります.

test/bit_shifter_test.exs
defmodule BitShifterTest do
  use ExUnit.Case
  doctest BitShifter

  describe "get" do
    test "shift = 0" do
      shifted =
        [0xAA, 0x55]
        |> BitShifter.ListWrapper.get(0)
        |> BitShifter.get(0)
        |> Enum.take(10)

      # The whole prefix is compared at once. Asserting inside a lazy
      # `Stream.map/2` would pass vacuously if fewer than 10 bytes came out.
      assert shifted == [0xAA, 0x55 | List.duplicate(0, 8)]
    end
  end
end

mix testすると,エラーになることを確認します.まだ実装していないので,これでOKです.

bit_shifter % mix test
    warning: BitShifter.get/2 is undefined or private
    │
  9 │       |> BitShifter.get(0)
    │                     ~
    │
    └─ test/bit_shifter_test.exs:9:21: BitShifterTest."test get shift = 0"/1

.

  1) test get shift = 0 (BitShifterTest)
     test/bit_shifter_test.exs:6
     ** (UndefinedFunctionError) function BitShifter.get/2 is undefined or private
     code: |> BitShifter.get(0)
     stacktrace:
       (bit_shifter 0.1.0) BitShifter.get(#Function<63.53678557/2 in Stream.unfold/2>, 0)
       test/bit_shifter_test.exs:9: (test)

..
Finished in 0.01 seconds (0.00s async, 0.01s sync)
1 doctest, 3 tests, 1 failure

Randomized with seed 590452
bit_shifter % 

次に,lib/bit_shifter.exを次のように実装します.

lib/bit_shifter.ex
defmodule BitShifter do
  @moduledoc """
  Right-shift support for an endless `Stream` of bytes.
  """

  @doc """
  Shifts the bytes produced by `stream` to the right by `shift` bits.

  At this stage only a `shift` of `0` is handled; it hands `stream` back
  unchanged. Any other `shift` has no matching clause.
  """
  @spec get(Enumerable.t(), non_neg_integer()) :: Enumerable.t()
  def get(stream, shift)

  def get(stream, 0) do
    stream
  end
end

そうすると,mix testが通ります.

% mix test
...
Finished in 0.01 seconds (0.00s async, 0.01s sync)
3 tests, 0 failures

shift = 8, shift = 16 の場合

次に8ビットシフトと16ビットシフトを足します.

test/bit_shifter_test.exs
defmodule BitShifterTest do
  use ExUnit.Case
  doctest BitShifter

  # Number of leading bytes each test materialises and checks.
  @prefix_length 10

  # Every test compares the whole materialised prefix. Asserting inside a lazy
  # `Stream.map/2` would pass vacuously if the shifted stream produced fewer
  # than `@prefix_length` bytes.
  describe "get" do
    test "shift = 0" do
      assert shifted_prefix(0) == zero_padded([0xAA, 0x55])
    end

    test "shift = 8" do
      assert shifted_prefix(8) == zero_padded([0, 0xAA, 0x55])
    end

    test "shift = 16" do
      assert shifted_prefix(16) == zero_padded([0, 0, 0xAA, 0x55])
    end
  end

  # Shifts the fixture bytes 0xAA, 0x55 (followed by endless zeros) by `shift`
  # bits and returns the first `@prefix_length` bytes of the result.
  defp shifted_prefix(shift) do
    [0xAA, 0x55]
    |> BitShifter.ListWrapper.get(0)
    |> BitShifter.get(shift)
    |> Enum.take(@prefix_length)
  end

  # Pads `bytes` with trailing zeros up to `@prefix_length` elements.
  defp zero_padded(bytes) do
    bytes ++ List.duplicate(0, @prefix_length - length(bytes))
  end
end

この時,当然ながらテストは失敗します.

次に実装します.

lib/bit_shifter.ex
defmodule BitShifter do
  @moduledoc """
  A module to right shift an infinite length byte `Stream`.
  """

  @doc """
  Returns the result of shifting the given infinite length byte `stream`
  by the given `shift` bits to the right.

  A `shift` of `0` returns `stream` unchanged. For a `shift` of 8 or more,
  one zero byte is emitted for every whole byte of shift and the remaining
  `shift mod 8` bits are delegated to `get/2` again. At this stage only a
  remainder of `0` has a matching clause, so a `shift` that is not a multiple
  of 8 raises `FunctionClauseError`.
  """
  @spec get(Enumerable.t(), non_neg_integer()) :: Enumerable.t()
  def get(stream, shift)

  def get(stream, 0), do: stream

  def get(stream, shift) when is_integer(shift) and 8 <= shift do
    leading_zero_bytes = Bitwise.bsr(shift, 3)
    remaining_bits = Bitwise.band(shift, 7)

    # Concatenation enumerates the source exactly once. Peeling one element
    # at a time with `Enum.at/2` and `Stream.drop/2` would restart the source
    # for every byte.
    Stream.concat(Stream.duplicate(0, leading_zero_bytes), get(stream, remaining_bits))
  end
end

これで,テストが通ることを確認します.

シフト値からマスクを生成する関数BitShifter.MaskFromShift.get/1を作成する

いよいよshiftが任意の場合にチャレンジするわけですが,その前に次のような関数BitShifter.MaskFromShift.get/1を作ります.

テストコードで仕様を見ていきます.

test/bit_shifter/mask_from_shift_test.exs
defmodule BitShifter.MaskFromShiftTest do
  use ExUnit.Case
  doctest BitShifter.MaskFromShift

  describe "get" do
    # One test per shift value, named after the shift, generated from a
    # table of {shift, expected mask} pairs.
    for {shift, expected_mask} <- [
          {0, 0},
          {1, 1},
          {2, 3},
          {3, 7},
          {4, 0xF},
          {5, 0x1F},
          {6, 0x3F},
          {7, 0x7F}
        ] do
      test "#{shift}" do
        assert BitShifter.MaskFromShift.get(unquote(shift)) == unquote(expected_mask)
      end
    end
  end
end

実装は次のようになります.

lib/bit_shifter/mask_from_shift.ex
defmodule BitShifter.MaskFromShift do
  @moduledoc """
  A module that generates mask bits corresponding to the shift value.
  """

  @doc """
  Returns mask bits corresponding to the given `shift` value.

  The mask has its lowest `shift` bits set, i.e. it equals `2^shift - 1`:
  `0` for a shift of 0, `1` for 1, `3` for 2, up to `0x7F` for 7.

  Only integer shifts from 0 to 7 are accepted; any other argument
  (including floats) raises `FunctionClauseError`.
  """
  @spec get(non_neg_integer()) :: non_neg_integer()
  def get(shift) when is_integer(shift) and 0 <= shift and shift < 8 do
    # (1 <<< shift) - 1 sets exactly the `shift` lowest bits.
    Bitwise.bsl(1, shift) - 1
  end
end

shiftを8で割った時の余りが0以外の場合

次がいよいよ山場です.

テストを実装します.

test/bit_shifter_test.exs
defmodule BitShifterTest do
  use ExUnit.Case
  doctest BitShifter

  # Number of leading bytes each test materialises and checks.
  @prefix_length 10

  # Every test compares the whole materialised prefix. Asserting inside a lazy
  # `Stream.map/2` would pass vacuously if the shifted stream produced fewer
  # than `@prefix_length` bytes.
  describe "get" do
    test "shift = 0" do
      assert shifted_prefix(0) == zero_padded([0xAA, 0x55])
    end

    test "shift = 8" do
      assert shifted_prefix(8) == zero_padded([0, 0xAA, 0x55])
    end

    test "shift = 16" do
      assert shifted_prefix(16) == zero_padded([0, 0, 0xAA, 0x55])
    end

    test "shift = 1" do
      assert shifted_prefix(1) == zero_padded([0x55, 0x2A, 0x80])
    end

    test "shift = 2" do
      assert shifted_prefix(2) == zero_padded([0x2A, 0x95, 0x40])
    end

    test "shift = 3" do
      assert shifted_prefix(3) == zero_padded([0x15, 0x4A, 0xA0])
    end

    test "shift = 4" do
      assert shifted_prefix(4) == zero_padded([0x0A, 0xA5, 0x50])
    end

    test "shift = 5" do
      assert shifted_prefix(5) == zero_padded([0x05, 0x52, 0xA8])
    end

    test "shift = 6" do
      assert shifted_prefix(6) == zero_padded([0x02, 0xA9, 0x54])
    end

    test "shift = 7" do
      assert shifted_prefix(7) == zero_padded([0x01, 0x54, 0xAA])
    end
  end

  # Shifts the fixture bytes 0xAA, 0x55 (followed by endless zeros) by `shift`
  # bits and returns the first `@prefix_length` bytes of the result.
  defp shifted_prefix(shift) do
    [0xAA, 0x55]
    |> BitShifter.ListWrapper.get(0)
    |> BitShifter.get(shift)
    |> Enum.take(@prefix_length)
  end

  # Pads `bytes` with trailing zeros up to `@prefix_length` elements.
  defp zero_padded(bytes) do
    bytes ++ List.duplicate(0, @prefix_length - length(bytes))
  end
end

テストが通らないことを確認します.

次に実装します.

lib/bit_shifter.ex
defmodule BitShifter do
  @moduledoc """
  A module to right shift an infinite length byte `Stream`.
  """

  @doc """
  Returns the result of shifting the given infinite length byte `stream`
  by the given `shift` bits to the right.

  Each element of `stream` is treated as one byte (`0..255`). The low bits
  shifted out of a byte become the high bits of the following output byte,
  and zero bits are shifted into the top of the first byte.

    * `shift == 0` returns `stream` unchanged.
    * `0 < shift < 8` combines every byte with the low bits carried over
      from the byte before it.
    * `shift >= 8` emits one zero byte for every whole byte of shift and
      then applies the remaining `shift mod 8` bits.

  The result is lazy and the source is enumerated only once, keeping a
  single byte of state. If the source is finite, the bits shifted out of
  its final byte are not emitted.
  """
  @spec get(Enumerable.t(), non_neg_integer()) :: Enumerable.t()
  def get(stream, shift)

  def get(stream, 0), do: stream

  def get(stream, shift) when is_integer(shift) and 0 < shift and shift < 8 do
    carry_shift = 8 - shift

    # The accumulator is the previous input byte (0 before the first one).
    # Its low `shift` bits move to the top of the current output byte.
    Stream.transform(stream, 0, fn byte, previous ->
      carried = previous |> Bitwise.bsl(carry_shift) |> Bitwise.band(0xFF)
      {[Bitwise.bor(carried, Bitwise.bsr(byte, shift))], byte}
    end)
  end

  def get(stream, shift) when is_integer(shift) and 8 <= shift do
    leading_zero_bytes = Bitwise.bsr(shift, 3)
    remaining_bits = Bitwise.band(shift, 7)

    # Concatenation enumerates the source exactly once. Peeling one element
    # at a time with `Enum.at/2` and `Stream.drop/2` would restart the source
    # for every byte.
    Stream.concat(Stream.duplicate(0, leading_zero_bytes), get(stream, remaining_bits))
  end
end

テストが通りました!

5
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
1