LoginSignup
2
0

More than 3 years have passed since last update.

Fortran で tree map を作ってみた

Last updated at Posted at 2019-07-24

動機

例のごとく、 Fortran には連想配列が実装されていないので自分で実装します。連想配列の実装は主にハッシュテーブルと平行二分探索木に分かれますが、今回は平行二分探索木 (AA 木 (Wikipedia)) を用いて実装します。はじめは赤黒木で実装しようと思ったのですが、実装が重く、バグが取れなかったので、比較的実装が簡単かつ赤黒木と同等な性能を持つ AA 木で実装しました。 Java の TreeMap に備わっているメソッドを参考にしました。

tree map とは

連想配列 (map) は要素をキーと値の組で管理するデータ構造です。簡単に言うと、インデックスとして連続した整数だけではなく、連続していない整数や小数、文字列も使える配列のようなものです。 1 つのキーに対して組にできる値は 1 個で、キーは重複して保持することはできません。 map にすでに存在するキーに対して新たに値を挿入する場合は、それまでに保持していた値を上書きすることになります。 連想配列は Java では Map 、 Go では map 、 Python では dict 、 Perl では hash として実装されています。連想配列の中で tree map は要素をキーの順序を保ってデータを保持する特徴があります。そのため、キーが最小・最大の要素や、指定したキーに一番近いキーの要素などを高速に取り出すことができます。 tree map は平行二分探索木を用いて実装されるため、 tree map が保持している要素数を $N$ とすると、要素の挿入、削除、検索、上書きの速度はどれも $O(\log N)$ となります。

AA 木について

AA 木とは平行二分探索木の一種で、 1993 年に Arne Andersson 氏が発表したので、イニシャルをとって AA 木と呼ばれています。コードの記述が短くシンプルになるように、アルゴリズムは他の平行二分探索木に比べてシンプルになっています。要素を挿入、削除する際に木の平衡性を保つための操作は skew と split の 2 つのみです。 Java の TreeMap で用いられている赤黒木も平行二分探索木の一種ですが、こちらは平衡性を保つために 4 パターン以上の場合分けをしなければならず、実装が大変重いです。しかしながら、赤黒木は最悪時間計算量が $O(\log N)$ と短いため、様々な場所で用いられています。 AA 木は赤黒木に比べて、平衡性を保つ操作を行う回数が多いですが、アルゴリズムが単純なので高速なため赤黒木と同等な性能を持つと言われており、また、実装も軽いです。 (参考: Balanced Search Trees Made Simple. Arne Andersson, 1993)

環境

macOS Sierra (10.12.6)
GNU Fortran (GCC) 6.3.0

プログラム

実際に実装したのが以下のものです。今回のコードはこれまでに比べて長くなりました。ここで実装している tree map はキーと値が integer のものです。

追記 2019/09/22

decrease_level() に不備が見つかったので、一部書き換えました。それに伴い、 level() を追加しました。

追記 2019/09/26

delete() に不備が見つかったので、一部書き換えました。

tree_map.f08
module mod_tree_map

  ! 要素
  type t_entry
    integer :: key ! キー
    integer :: val ! 値
  end type t_entry

  ! 木のノード
  type t_node
    type(t_node), pointer :: left => null(), right => null() ! 子ノード
    integer :: level = 1                                     ! ノードの階層
    type(t_entry), pointer :: e => null()                    ! 保持する要素
  end type t_node

  type t_tree_map
    type(t_node), pointer :: root => null() ! 木の根ノード
    integer :: deflt = -1                   ! キーが存在しない場合のデフォルト値
  end type t_tree_map

contains

  ! 要素のコンストラクタ
  function new_entry(key,val) result(e)
    implicit none
    integer, intent(in) :: key
    integer, intent(in) :: val
    type(t_entry), pointer :: e

    e => null()
    allocate(e)
    e%key = key
    e%val = val
    return
  end function new_entry

  ! ノードのコンストラクタ
  function new_node(e) result(n)
    implicit none
    type(t_entry), pointer, intent(in) :: e
    type(t_node), pointer :: n

    n => null()
    allocate(n)
    n%e => e
    return
  end function new_node

  ! ノードの階層
  integer function level(n)
    type(t_node), pointer, intent(in) :: n

    level = 0
    if (.not.associated(n)) return
    level = n%level
  end

  ! ノードが葉であるか
  logical function is_leaf(n)
    implicit none
    type(t_node), pointer, intent(in) :: n

    is_leaf = associated(n) .and. .not.associated(n%left) .and. &
    &         .not.associated(n%right)
    return
  end function is_leaf

  ! 木のサイズ
  recursive function tree_size(n) result(s)
    implicit none
    type(t_node), pointer, intent(in) :: n
    integer :: s

    s = 0
    if (.not.associated(n)) return
    s = 1+tree_size(n%left)+tree_size(n%right)
    return
  end function tree_size

  ! 木の平衡性を保つ操作
  function skew(n) result(l)
    implicit none
    type(t_node), pointer, intent(in) :: n
    type(t_node), pointer :: l

    l => n
    if (.not.(associated(n) .and. associated(n%left) .and. &
    &   n%left%level == n%level)) return
    l => n%left
    n%left => l%right
    l%right => n
    return
  end function skew

  ! 木の平衡性を保つ操作
  function split(n) result(r)
    implicit none
    type(t_node), pointer, intent(in) :: n
    type(t_node), pointer :: r

    r => n
    if (.not.(associated(n) .and. associated(n%right) .and. &
    &   associated(n%right%right) .and. n%right%right%level == n%level)) return
    r => n%right
    n%right => r%left
    r%left => n
    r%level = r%level+1
    return
  end function split

  ! 最も近い先行ノード
  function predecessor(n) result(p)
    implicit none
    type(t_node), pointer, intent(in) :: n
    type(t_node), pointer :: p

    p => null()
    if (.not.associated(n%left)) return
    p => n%left
    do while (associated(p%right))
      p => p%right
    end do
    return
  end function predecessor

  ! 最も近い後続ノード
  function successor(n) result(s)
    implicit none
    type(t_node), pointer, intent(in) :: n
    type(t_node), pointer :: s

    s => null()
    if (.not.associated(n%right)) return
    s => n%right
    do while (associated(s%left))
      s => s%left
    end do
    return
  end function successor

  ! 要素の挿入
  recursive function insert(n,e) result(t)
    implicit none
    type(t_node), pointer, intent(in) :: n
    type(t_entry), pointer, intent(in) :: e
    type(t_node), pointer :: t

    t => new_node(e)
    if (.not.associated(n)) return
    t => n
    if (e%key < t%e%key) then
      t%left => insert(t%left,e)
    else if (e%key > t%e%key) then
      t%right => insert(t%right,e)
    else
      t%e => e
    end if

    t => skew(t)
    t => split(t)
    return
  end function insert

  ! 要素の削除
  recursive function delete(n,e) result(t)
    implicit none
    type(t_node), pointer, intent(in) :: n
    type(t_entry), pointer, intent(in) :: e
    type(t_node), pointer :: t, l

    t => n
    if (.not.associated(n)) return
    if (e%key < t%e%key) then
      t%left => delete(t%left,e)
    else if (e%key > t%e%key) then
      t%right => delete(t%right,e)
    else
      if (is_leaf(t)) then ! changed 2019/09/26
        t => null()
        return
      end if
      if (.not.associated(t%left)) then
        l => successor(t)
        t%right => delete(t%right,l%e)
        t%e => l%e
      else
        l => predecessor(t)
        t%left => delete(t%left,l%e)
        t%e => l%e
      end if
    end if

    t => decrease_level(t)
    t => skew(t)
    t%right => skew(t%right)
    if (associated(t%right)) t%right%right => skew(t%right%right)
    t => split(t)
    t%right => split(t%right)
    return
  end function delete

  ! 削除の操作でノードの階層が変わるときの操作
  function decrease_level(n) result(t)
    type(t_node), pointer, intent(in) :: n
    type(t_node), pointer :: t
    integer :: should_be

    t => n
    should_be = min(level(t%left),level(t%right))+1
    if (t%level > should_be) then
      t%level = should_be
      if (level(t%right) > should_be) t%right%level = should_be
    end if
  end

  ! 木の全要素の削除
  recursive subroutine release_tree(t)
    implicit none
    type(t_node), pointer, intent(inout) :: t

    if (.not.associated(t)) return
    call release_tree(t%left)
    call release_tree(t%right)
    deallocate(t)
    return
  end subroutine release_tree

  ! 木の全要素のキーの取得
  recursive subroutine get_keys_list(t,keys,num)
    implicit none
    type(t_node), pointer, intent(in) :: t
    integer, intent(inout) :: keys(:)
    integer, intent(inout) :: num

    if (.not.associated(t)) return
    call get_keys_list(t%left,keys,num)
    num = num+1
    keys(num) = t%e%key
    call get_keys_list(t%right,keys,num)
    return
  end subroutine get_keys_list

  ! マップのサイズ
  integer function size_of(map)
    implicit none
    type(t_tree_map), intent(in) :: map

    size_of = tree_size(map%root)
    return
  end function size_of

  ! マップの全要素の削除
  subroutine clear(map)
    implicit none
    type(t_tree_map), intent(inout) :: map

    call release_tree(map%root)
    map%root => null()
    return
  end subroutine clear

  ! マップのデフォルト値の変更
  subroutine set_default(map,deflt)
    implicit none
    type(t_tree_map), intent(inout) :: map
    integer, intent(in) :: deflt

    map%deflt = deflt
    return
  end subroutine set_default

  ! マップへの要素の挿入
  subroutine put_entry(map,e)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_entry), pointer, intent(in) :: e

    map%root => insert(map%root,e)
    return
  end subroutine put_entry

  ! マップから要素の削除
  subroutine remove_entry(map,e)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_entry), pointer, intent(in) :: e

    map%root => delete(map%root,e)
    return
  end subroutine remove_entry

  ! マップから要素の取得
  function get_entry(map,e) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_entry), pointer, intent(in) :: e
    type(t_node), pointer :: n
    type(t_entry), pointer :: ret

    ret => null()
    n => map%root
    do while (associated(n))
      if (e%key < n%e%key) then
        n => n%left
      else if (e%key > n%e%key) then
        n => n%right
      else
        ret => n%e
        return
      end if
    end do
    return
  end function get_entry

  ! マップが要素を保持しているか
  function contain_entry(map,e) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_entry), pointer, intent(in) :: e
    type(t_node), pointer :: n
    logical :: ret

    ret = .false.
    n => map%root
    do while (associated(n))
      if (e%key < n%e%key) then
        n => n%left
      else if (e%key > n%e%key) then
        n => n%right
      else
        ret = .true.
        return
      end if
    end do
    return
  end function contain_entry

  ! キーが最小の要素の取得
  function get_first_entry(map) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_node), pointer :: n
    type(t_entry), pointer :: ret

    ret => null()
    n => map%root
    if (.not.associated(n)) return
    do while (associated(n%left))
      n => n%left
    end do
    ret => n%e
    return
  end function get_first_entry

  ! キーが最小の要素の削除
  function poll_first_entry(map) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_node), pointer :: n
    type(t_entry), pointer :: ret

    ret => null()
    n => map%root
    if (.not.associated(n)) return
    do while (associated(n%left))
      n => n%left
    end do
    ret => n%e

    map%root => delete(map%root,ret)
    return
  end function poll_first_entry

  ! キーが最大の要素の取得
  function get_last_entry(map) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_node), pointer :: n
    type(t_entry), pointer :: ret

    ret => null()
    n => map%root
    if (.not.associated(n)) return
    do while (associated(n%right))
      n => n%right
    end do
    ret => n%e
    return
  end function get_last_entry

  ! キーが最大の要素の削除
  function poll_last_entry(map) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_node), pointer :: n
    type(t_entry), pointer :: ret

    ret => null()
    n => map%root
    if (.not.associated(n)) return
    do while (associated(n%right))
      n => n%right
    end do
    ret => n%e

    map%root => delete(map%root,ret)
    return
  end function poll_last_entry

  ! キーが指定した要素のキー以下の要素の取得
  function floor_entry(map,e) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_entry), pointer, intent(in) :: e
    type(t_node), pointer :: n
    type(t_entry), pointer :: ret

    ret => null()
    n => map%root
    do while (associated(n))
      if (e%key < n%e%key) then
        n => n%left
      else if (e%key > n%e%key) then
        if (.not.associated(ret)) then
          ret => n%e
          cycle
        end if
        if (e%key-ret%key > e%key-n%e%key) ret => n%e
        n => n%right
      else
        ret => n%e
        return
      end if
    end do
    return
  end function floor_entry

  ! キーが指定した要素のキーより小さいの要素の取得
  function lower_entry(map,e) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_entry), pointer, intent(in) :: e
    type(t_entry), pointer :: ret

    ret => floor_entry(map,new_entry(e%key-1,0))
    return
  end function lower_entry

  ! キーが指定した要素のキー以上の要素の取得
  function ceiling_entry(map,e) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_entry), pointer, intent(in) :: e
    type(t_node), pointer :: n
    type(t_entry), pointer :: ret

    ret => null()
    n => map%root
    do while (associated(n))
      if (e%key < n%e%key) then
        if (.not.associated(ret)) then
          ret => n%e
          cycle
        end if
        if (e%key-ret%key < e%key-n%e%key) ret => n%e
        n => n%left
      else if (e%key > n%e%key) then
        n => n%right
      else
        ret => n%e
        return
      end if
    end do
    return
  end function ceiling_entry

  ! キーが指定した要素のキーより大きいの要素の取得
  function higher_entry(map,e) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_entry), pointer, intent(in) :: e
    type(t_entry), pointer :: ret

    ret => ceiling_entry(map,new_entry(e%key+1,0))
    return
  end function higher_entry

  ! マップの全ての要素のキーの取得
  subroutine get_keys(map,keys,num)
    implicit none
    type(t_tree_map), intent(inout) :: map
    integer, intent(inout) :: keys(:)
    integer, intent(inout) :: num

    keys = 0
    num = 0
    call get_keys_list(map%root,keys,num)
    return
  end subroutine get_keys

  ! キーと値の指定によるマップへの要素の挿入
  subroutine put(map,key,val)
    implicit none
    type(t_tree_map), intent(inout) :: map
    integer, intent(in) :: key
    integer, intent(in) :: val

    call put_entry(map,new_entry(key,val))
    return
  end subroutine put

  ! キーの指定によるマップから要素の削除
  subroutine remove(map,key)
    implicit none
    type(t_tree_map), intent(inout) :: map
    integer, intent(in) :: key

    call remove_entry(map,new_entry(key,0))
    return
  end subroutine remove

  ! キーの指定によるマップから要素の取得
  function get(map,key) result(val)
    implicit none
    type(t_tree_map), intent(inout) :: map
    integer, intent(in) :: key
    type(t_entry), pointer :: tmp
    integer :: val

    val = map%deflt
    tmp => get_entry(map,new_entry(key,0))
    if (.not.associated(tmp)) return
    val = tmp%val
    return
  end function get

  ! キーの指定によるマップが要素を保持しているか
  logical function contain(map,key)
    implicit none
    type(t_tree_map), intent(inout) :: map
    integer, intent(in) :: key

    contain = contain_entry(map,new_entry(key,0))
    return
  end function contain

  ! 最小のキーの取得
  function get_first_key(map) result(key)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_entry), pointer :: tmp
    integer :: key

    key = map%deflt
    tmp => get_first_entry(map)
    if (.not.associated(tmp)) return
    key = tmp%key
    return
  end function get_first_key

  ! 最大のキーの取得
  function get_last_key(map) result(key)
    implicit none
    type(t_tree_map), intent(inout) :: map
    type(t_entry), pointer :: tmp
    integer :: key

    key = map%deflt
    tmp => get_last_entry(map)
    if (.not.associated(tmp)) return
    key = tmp%key
    return
  end function get_last_key

  ! 指定したキー以下のキーの取得
  function floor_key(map,key) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    integer, intent(in) :: key
    type(t_entry), pointer :: tmp
    integer :: ret

    ret = map%deflt
    tmp => floor_entry(map,new_entry(key,0))
    if (.not.associated(tmp)) return
    ret = tmp%key
    return
  end function floor_key

  ! 指定したキーより小さいのキーの取得
  function lower_key(map,key) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    integer, intent(in) :: key
    type(t_entry), pointer :: tmp
    integer :: ret

    ret = map%deflt
    tmp => lower_entry(map,new_entry(key,0))
    if (.not.associated(tmp)) return
    ret = tmp%key
    return
  end function lower_key

  ! 指定したキー以上のキーの取得
  function ceiling_key(map,key) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    integer, intent(in) :: key
    type(t_entry), pointer :: tmp
    integer :: ret

    ret = map%deflt
    tmp => ceiling_entry(map,new_entry(key,0))
    if (.not.associated(tmp)) return
    ret = tmp%key
    return
  end function ceiling_key

  ! 指定したキーより大きいキーの取得
  function higher_key(map,key) result(ret)
    implicit none
    type(t_tree_map), intent(inout) :: map
    integer, intent(in) :: key
    type(t_entry), pointer :: tmp
    integer :: ret

    ret = map%deflt
    tmp => higher_entry(map,new_entry(key,0))
    if (.not.associated(tmp)) return
    ret = tmp%key
    return
  end function higher_key

end module mod_tree_map

また、実装できているか確認するために以下のプログラムも書きました。

test_tree_map.f08
program test_tree_map
  use mod_tree_map
  implicit none
  type(t_tree_map) :: map
  integer :: keys(100), vals(100), num, i

  call put(map,1,10)
  call put(map,8,80)
  call put(map,4,40)
  call put(map,2,20)
  call put(map,5,50)
  call put(map,9,90)
  call put(map,3,30)
  call put(map,6,60)
  call put(map,7,70)

  write(*,'(a)',advance='no') "keys: "
  call get_keys(map,keys,num)
  call output(keys(1:num))

  write(*,'(a)',advance='no') "vals: "
  do i = 1, num
    vals(i) = get(map,keys(i))
  end do
  call output(vals(1:num))

  write(*,'("key: ",i0," val: ",i0)') 10, get(map,10)
  write(*,'("first key: ",i0)') get_first_key(map)
  write(*,'("last key: ",i0)') get_last_key(map)
  write(*,'("lower key of ",i0,": ",i0)') 5, lower_key(map,5)
  write(*,'("floor key of ",i0,": ",i0)') 5, floor_key(map,5)
  write(*,'("ceiling key of ",i0,": ",i0)') 5, ceiling_key(map,5)
  write(*,'("higher key of ",i0,": ",i0)') 5, higher_key(map,5)

  call put(map,5,500)
  write(*,'("key: ",i0," val: ",i0)') 5, get(map,5)

  write(*,'("size: ",i0)') size_of(map)
  call clear(map)
  write(*,'("size: ",i0)') size_of(map)

  call put(map,1,10)
  call put(map,5,50)
  call put(map,8,80)
  call put(map,3,30)
  call put(map,7,70)
  call put(map,4,40)
  call put(map,2,20)
  call put(map,9,90)
  call put(map,6,60)

  call set_default(map,-1000000)
  write(*,'("key: ",i0," val: ",i0)') 5, get(map,5)
  write(*,'("key: ",i0," val: ",i0)') 10, get(map,10)

  write(*,'(a)',advance='no') "keys: "
  call get_keys(map,keys,num)
  call output(keys(1:num))

  write(*,'(a)',advance='no') "vals: "
  do i = 1, num
    vals(i) = get(map,keys(i))
  end do
  call output(vals(1:num))

  write(*,'("lower key of ",i0,": ",i0)') 1, lower_key(map,1)
  write(*,'("floor key of ",i0,": ",i0)') 1, floor_key(map,1)
  write(*,'("ceiling key of ",i0,": ",i0)') 9, ceiling_key(map,9)
  write(*,'("higher key of ",i0,": ",i0)') 9, higher_key(map,9)

  stop
contains

  subroutine output(a)
    implicit none
    integer, intent(in) :: a(:)
    integer :: n, i
    n = size(a)
    do i = 1, n
      write(*,'(i0,x)',advance='no') a(i)
    end do
    write(*,*)
    return
  end subroutine output

end program test_tree_map

実行結果は以下のとおりです。

keys: 1 2 3 4 5 6 7 8 9
vals: 10 20 30 40 50 60 70 80 90
key: 10 val: -1
first key: 1
last key: 9
lower key of 5: 4
floor key of 5: 5
ceiling key of 5: 5
higher key of 5: 6
key: 5 val: 500
size: 9
size: 0
key: 5 val: 50
key: 10 val: -1000000
keys: 1 2 3 4 5 6 7 8 9
vals: 10 20 30 40 50 60 70 80 90
lower key of 1: -1000000
floor key of 1: 1
ceiling key of 9: 9
higher key of 9: -1000000

これで Fortran でも連想配列を使うことができるようになりました。 map を実装すると、同時に set も実装できていることになるので、これで Fortran で set が必要になっても安心です。ただ、キーと値の型ごとに作り直さなければならないのが面倒ですね。

追記 2019/08/11

不要な部分を削り、少し短くなりました (ソースコード) 。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0