動機
例のごとく、 Fortran には連想配列が実装されていないので自分で実装します。連想配列の実装は主にハッシュテーブルと平行二分探索木に分かれますが、今回は平行二分探索木 (AA 木 (Wikipedia)) を用いて実装します。はじめは赤黒木で実装しようと思ったのですが、実装が重く、バグが取れなかったので、比較的実装が簡単かつ赤黒木と同等な性能を持つ AA 木で実装しました。 Java の TreeMap に備わっているメソッドを参考にしました。
tree map とは
連想配列 (map) は要素をキーと値の組で管理するデータ構造です。簡単に言うと、インデックスとして連続した整数だけではなく、連続していない整数や小数、文字列も使える配列のようなものです。 1 つのキーに対して組にできる値は 1 個で、キーは重複して保持することはできません。 map にすでに存在するキーに対して新たに値を挿入する場合は、それまでに保持していた値を上書きすることになります。 連想配列は Java では Map 、 Go では map 、 Python では dict 、 Perl では hash として実装されています。連想配列の中で tree map は要素をキーの順序を保ってデータを保持する特徴があります。そのため、キーが最小・最大の要素や、指定したキーに一番近いキーの要素などを高速に取り出すことができます。 tree map は平行二分探索木を用いて実装されるため、 tree map が保持している要素数を $N$ とすると、要素の挿入、削除、検索、上書きの速度はどれも $O(\log N)$ となります。
AA 木について
AA 木とは平行二分探索木の一種で、 1993 年に Arne Andersson 氏が発表したので、イニシャルをとって AA 木と呼ばれています。コードの記述が短くシンプルになるように、アルゴリズムは他の平行二分探索木に比べてシンプルになっています。要素を挿入、削除する際に木の平衡性を保つための操作は skew と split の 2 つのみです。 Java の TreeMap で用いられている赤黒木も平行二分探索木の一種ですが、こちらは平衡性を保つために 4 パターン以上の場合分けをしなければならず、実装が大変重いです。しかしながら、赤黒木は最悪時間計算量が $O(\log N)$ と短いため、様々な場所で用いられています。 AA 木は赤黒木に比べて、平衡性を保つ操作を行う回数が多いですが、アルゴリズムが単純なので高速なため赤黒木と同等な性能を持つと言われており、また、実装も軽いです。 (参考: Balanced Search Trees Made Simple. Arne Andersson, 1993)
環境
macOS Sierra (10.12.6)
GNU Fortran (GCC) 6.3.0
プログラム
実際に実装したのが以下のものです。今回のコードはこれまでに比べて長くなりました。ここで実装している tree map はキーと値が integer のものです。
追記 2019/09/22
decrease_level()
に不備が見つかったので、一部書き換えました。それに伴い、 level()
を追加しました。
追記 2019/09/26
delete()
に不備が見つかったので、一部書き換えました。
module mod_tree_map
! 要素
type t_entry
integer :: key ! キー
integer :: val ! 値
end type t_entry
! 木のノード
type t_node
type(t_node), pointer :: left => null(), right => null() ! 子ノード
integer :: level = 1 ! ノードの階層
type(t_entry), pointer :: e => null() ! 保持する要素
end type t_node
type t_tree_map
type(t_node), pointer :: root => null() ! 木の根ノード
integer :: deflt = -1 ! キーが存在しない場合のデフォルト値
end type t_tree_map
contains
! 要素のコンストラクタ
function new_entry(key,val) result(e)
implicit none
integer, intent(in) :: key
integer, intent(in) :: val
type(t_entry), pointer :: e
e => null()
allocate(e)
e%key = key
e%val = val
return
end function new_entry
! ノードのコンストラクタ
function new_node(e) result(n)
implicit none
type(t_entry), pointer, intent(in) :: e
type(t_node), pointer :: n
n => null()
allocate(n)
n%e => e
return
end function new_node
! ノードの階層
integer function level(n)
type(t_node), pointer, intent(in) :: n
level = 0
if (.not.associated(n)) return
level = n%level
end
! ノードが葉であるか
logical function is_leaf(n)
implicit none
type(t_node), pointer, intent(in) :: n
is_leaf = associated(n) .and. .not.associated(n%left) .and. &
& .not.associated(n%right)
return
end function is_leaf
! 木のサイズ
recursive function tree_size(n) result(s)
implicit none
type(t_node), pointer, intent(in) :: n
integer :: s
s = 0
if (.not.associated(n)) return
s = 1+tree_size(n%left)+tree_size(n%right)
return
end function tree_size
! 木の平衡性を保つ操作
function skew(n) result(l)
implicit none
type(t_node), pointer, intent(in) :: n
type(t_node), pointer :: l
l => n
if (.not.(associated(n) .and. associated(n%left) .and. &
& n%left%level == n%level)) return
l => n%left
n%left => l%right
l%right => n
return
end function skew
! 木の平衡性を保つ操作
function split(n) result(r)
implicit none
type(t_node), pointer, intent(in) :: n
type(t_node), pointer :: r
r => n
if (.not.(associated(n) .and. associated(n%right) .and. &
& associated(n%right%right) .and. n%right%right%level == n%level)) return
r => n%right
n%right => r%left
r%left => n
r%level = r%level+1
return
end function split
! 最も近い先行ノード
function predecessor(n) result(p)
implicit none
type(t_node), pointer, intent(in) :: n
type(t_node), pointer :: p
p => null()
if (.not.associated(n%left)) return
p => n%left
do while (associated(p%right))
p => p%right
end do
return
end function predecessor
! 最も近い後続ノード
function successor(n) result(s)
implicit none
type(t_node), pointer, intent(in) :: n
type(t_node), pointer :: s
s => null()
if (.not.associated(n%right)) return
s => n%right
do while (associated(s%left))
s => s%left
end do
return
end function successor
! 要素の挿入
recursive function insert(n,e) result(t)
implicit none
type(t_node), pointer, intent(in) :: n
type(t_entry), pointer, intent(in) :: e
type(t_node), pointer :: t
t => new_node(e)
if (.not.associated(n)) return
t => n
if (e%key < t%e%key) then
t%left => insert(t%left,e)
else if (e%key > t%e%key) then
t%right => insert(t%right,e)
else
t%e => e
end if
t => skew(t)
t => split(t)
return
end function insert
! 要素の削除
recursive function delete(n,e) result(t)
implicit none
type(t_node), pointer, intent(in) :: n
type(t_entry), pointer, intent(in) :: e
type(t_node), pointer :: t, l
t => n
if (.not.associated(n)) return
if (e%key < t%e%key) then
t%left => delete(t%left,e)
else if (e%key > t%e%key) then
t%right => delete(t%right,e)
else
if (is_leaf(t)) then ! changed 2019/09/26
t => null()
return
end if
if (.not.associated(t%left)) then
l => successor(t)
t%right => delete(t%right,l%e)
t%e => l%e
else
l => predecessor(t)
t%left => delete(t%left,l%e)
t%e => l%e
end if
end if
t => decrease_level(t)
t => skew(t)
t%right => skew(t%right)
if (associated(t%right)) t%right%right => skew(t%right%right)
t => split(t)
t%right => split(t%right)
return
end function delete
! 削除の操作でノードの階層が変わるときの操作
function decrease_level(n) result(t)
type(t_node), pointer, intent(in) :: n
type(t_node), pointer :: t
integer :: should_be
t => n
should_be = min(level(t%left),level(t%right))+1
if (t%level > should_be) then
t%level = should_be
if (level(t%right) > should_be) t%right%level = should_be
end if
end
! 木の全要素の削除
recursive subroutine release_tree(t)
implicit none
type(t_node), pointer, intent(inout) :: t
if (.not.associated(t)) return
call release_tree(t%left)
call release_tree(t%right)
deallocate(t)
return
end subroutine release_tree
! 木の全要素のキーの取得
recursive subroutine get_keys_list(t,keys,num)
implicit none
type(t_node), pointer, intent(in) :: t
integer, intent(inout) :: keys(:)
integer, intent(inout) :: num
if (.not.associated(t)) return
call get_keys_list(t%left,keys,num)
num = num+1
keys(num) = t%e%key
call get_keys_list(t%right,keys,num)
return
end subroutine get_keys_list
! マップのサイズ
integer function size_of(map)
implicit none
type(t_tree_map), intent(in) :: map
size_of = tree_size(map%root)
return
end function size_of
! マップの全要素の削除
subroutine clear(map)
implicit none
type(t_tree_map), intent(inout) :: map
call release_tree(map%root)
map%root => null()
return
end subroutine clear
! マップのデフォルト値の変更
subroutine set_default(map,deflt)
implicit none
type(t_tree_map), intent(inout) :: map
integer, intent(in) :: deflt
map%deflt = deflt
return
end subroutine set_default
! マップへの要素の挿入
subroutine put_entry(map,e)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_entry), pointer, intent(in) :: e
map%root => insert(map%root,e)
return
end subroutine put_entry
! マップから要素の削除
subroutine remove_entry(map,e)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_entry), pointer, intent(in) :: e
map%root => delete(map%root,e)
return
end subroutine remove_entry
! マップから要素の取得
function get_entry(map,e) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_entry), pointer, intent(in) :: e
type(t_node), pointer :: n
type(t_entry), pointer :: ret
ret => null()
n => map%root
do while (associated(n))
if (e%key < n%e%key) then
n => n%left
else if (e%key > n%e%key) then
n => n%right
else
ret => n%e
return
end if
end do
return
end function get_entry
! マップが要素を保持しているか
function contain_entry(map,e) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_entry), pointer, intent(in) :: e
type(t_node), pointer :: n
logical :: ret
ret = .false.
n => map%root
do while (associated(n))
if (e%key < n%e%key) then
n => n%left
else if (e%key > n%e%key) then
n => n%right
else
ret = .true.
return
end if
end do
return
end function contain_entry
! キーが最小の要素の取得
function get_first_entry(map) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_node), pointer :: n
type(t_entry), pointer :: ret
ret => null()
n => map%root
if (.not.associated(n)) return
do while (associated(n%left))
n => n%left
end do
ret => n%e
return
end function get_first_entry
! キーが最小の要素の削除
function poll_first_entry(map) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_node), pointer :: n
type(t_entry), pointer :: ret
ret => null()
n => map%root
if (.not.associated(n)) return
do while (associated(n%left))
n => n%left
end do
ret => n%e
map%root => delete(map%root,ret)
return
end function poll_first_entry
! キーが最大の要素の取得
function get_last_entry(map) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_node), pointer :: n
type(t_entry), pointer :: ret
ret => null()
n => map%root
if (.not.associated(n)) return
do while (associated(n%right))
n => n%right
end do
ret => n%e
return
end function get_last_entry
! キーが最大の要素の削除
function poll_last_entry(map) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_node), pointer :: n
type(t_entry), pointer :: ret
ret => null()
n => map%root
if (.not.associated(n)) return
do while (associated(n%right))
n => n%right
end do
ret => n%e
map%root => delete(map%root,ret)
return
end function poll_last_entry
! キーが指定した要素のキー以下の要素の取得
function floor_entry(map,e) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_entry), pointer, intent(in) :: e
type(t_node), pointer :: n
type(t_entry), pointer :: ret
ret => null()
n => map%root
do while (associated(n))
if (e%key < n%e%key) then
n => n%left
else if (e%key > n%e%key) then
if (.not.associated(ret)) then
ret => n%e
cycle
end if
if (e%key-ret%key > e%key-n%e%key) ret => n%e
n => n%right
else
ret => n%e
return
end if
end do
return
end function floor_entry
! キーが指定した要素のキーより小さいの要素の取得
function lower_entry(map,e) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_entry), pointer, intent(in) :: e
type(t_entry), pointer :: ret
ret => floor_entry(map,new_entry(e%key-1,0))
return
end function lower_entry
! キーが指定した要素のキー以上の要素の取得
function ceiling_entry(map,e) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_entry), pointer, intent(in) :: e
type(t_node), pointer :: n
type(t_entry), pointer :: ret
ret => null()
n => map%root
do while (associated(n))
if (e%key < n%e%key) then
if (.not.associated(ret)) then
ret => n%e
cycle
end if
if (e%key-ret%key < e%key-n%e%key) ret => n%e
n => n%left
else if (e%key > n%e%key) then
n => n%right
else
ret => n%e
return
end if
end do
return
end function ceiling_entry
! キーが指定した要素のキーより大きいの要素の取得
function higher_entry(map,e) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_entry), pointer, intent(in) :: e
type(t_entry), pointer :: ret
ret => ceiling_entry(map,new_entry(e%key+1,0))
return
end function higher_entry
! マップの全ての要素のキーの取得
subroutine get_keys(map,keys,num)
implicit none
type(t_tree_map), intent(inout) :: map
integer, intent(inout) :: keys(:)
integer, intent(inout) :: num
keys = 0
num = 0
call get_keys_list(map%root,keys,num)
return
end subroutine get_keys
! キーと値の指定によるマップへの要素の挿入
subroutine put(map,key,val)
implicit none
type(t_tree_map), intent(inout) :: map
integer, intent(in) :: key
integer, intent(in) :: val
call put_entry(map,new_entry(key,val))
return
end subroutine put
! キーの指定によるマップから要素の削除
subroutine remove(map,key)
implicit none
type(t_tree_map), intent(inout) :: map
integer, intent(in) :: key
call remove_entry(map,new_entry(key,0))
return
end subroutine remove
! キーの指定によるマップから要素の取得
function get(map,key) result(val)
implicit none
type(t_tree_map), intent(inout) :: map
integer, intent(in) :: key
type(t_entry), pointer :: tmp
integer :: val
val = map%deflt
tmp => get_entry(map,new_entry(key,0))
if (.not.associated(tmp)) return
val = tmp%val
return
end function get
! キーの指定によるマップが要素を保持しているか
logical function contain(map,key)
implicit none
type(t_tree_map), intent(inout) :: map
integer, intent(in) :: key
contain = contain_entry(map,new_entry(key,0))
return
end function contain
! 最小のキーの取得
function get_first_key(map) result(key)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_entry), pointer :: tmp
integer :: key
key = map%deflt
tmp => get_first_entry(map)
if (.not.associated(tmp)) return
key = tmp%key
return
end function get_first_key
! 最大のキーの取得
function get_last_key(map) result(key)
implicit none
type(t_tree_map), intent(inout) :: map
type(t_entry), pointer :: tmp
integer :: key
key = map%deflt
tmp => get_last_entry(map)
if (.not.associated(tmp)) return
key = tmp%key
return
end function get_last_key
! 指定したキー以下のキーの取得
function floor_key(map,key) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
integer, intent(in) :: key
type(t_entry), pointer :: tmp
integer :: ret
ret = map%deflt
tmp => floor_entry(map,new_entry(key,0))
if (.not.associated(tmp)) return
ret = tmp%key
return
end function floor_key
! 指定したキーより小さいのキーの取得
function lower_key(map,key) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
integer, intent(in) :: key
type(t_entry), pointer :: tmp
integer :: ret
ret = map%deflt
tmp => lower_entry(map,new_entry(key,0))
if (.not.associated(tmp)) return
ret = tmp%key
return
end function lower_key
! 指定したキー以上のキーの取得
function ceiling_key(map,key) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
integer, intent(in) :: key
type(t_entry), pointer :: tmp
integer :: ret
ret = map%deflt
tmp => ceiling_entry(map,new_entry(key,0))
if (.not.associated(tmp)) return
ret = tmp%key
return
end function ceiling_key
! 指定したキーより大きいキーの取得
function higher_key(map,key) result(ret)
implicit none
type(t_tree_map), intent(inout) :: map
integer, intent(in) :: key
type(t_entry), pointer :: tmp
integer :: ret
ret = map%deflt
tmp => higher_entry(map,new_entry(key,0))
if (.not.associated(tmp)) return
ret = tmp%key
return
end function higher_key
end module mod_tree_map
また、実装できているか確認するために以下のプログラムも書きました。
program test_tree_map
use mod_tree_map
implicit none
type(t_tree_map) :: map
integer :: keys(100), vals(100), num, i
call put(map,1,10)
call put(map,8,80)
call put(map,4,40)
call put(map,2,20)
call put(map,5,50)
call put(map,9,90)
call put(map,3,30)
call put(map,6,60)
call put(map,7,70)
write(*,'(a)',advance='no') "keys: "
call get_keys(map,keys,num)
call output(keys(1:num))
write(*,'(a)',advance='no') "vals: "
do i = 1, num
vals(i) = get(map,keys(i))
end do
call output(vals(1:num))
write(*,'("key: ",i0," val: ",i0)') 10, get(map,10)
write(*,'("first key: ",i0)') get_first_key(map)
write(*,'("last key: ",i0)') get_last_key(map)
write(*,'("lower key of ",i0,": ",i0)') 5, lower_key(map,5)
write(*,'("floor key of ",i0,": ",i0)') 5, floor_key(map,5)
write(*,'("ceiling key of ",i0,": ",i0)') 5, ceiling_key(map,5)
write(*,'("higher key of ",i0,": ",i0)') 5, higher_key(map,5)
call put(map,5,500)
write(*,'("key: ",i0," val: ",i0)') 5, get(map,5)
write(*,'("size: ",i0)') size_of(map)
call clear(map)
write(*,'("size: ",i0)') size_of(map)
call put(map,1,10)
call put(map,5,50)
call put(map,8,80)
call put(map,3,30)
call put(map,7,70)
call put(map,4,40)
call put(map,2,20)
call put(map,9,90)
call put(map,6,60)
call set_default(map,-1000000)
write(*,'("key: ",i0," val: ",i0)') 5, get(map,5)
write(*,'("key: ",i0," val: ",i0)') 10, get(map,10)
write(*,'(a)',advance='no') "keys: "
call get_keys(map,keys,num)
call output(keys(1:num))
write(*,'(a)',advance='no') "vals: "
do i = 1, num
vals(i) = get(map,keys(i))
end do
call output(vals(1:num))
write(*,'("lower key of ",i0,": ",i0)') 1, lower_key(map,1)
write(*,'("floor key of ",i0,": ",i0)') 1, floor_key(map,1)
write(*,'("ceiling key of ",i0,": ",i0)') 9, ceiling_key(map,9)
write(*,'("higher key of ",i0,": ",i0)') 9, higher_key(map,9)
stop
contains
subroutine output(a)
implicit none
integer, intent(in) :: a(:)
integer :: n, i
n = size(a)
do i = 1, n
write(*,'(i0,x)',advance='no') a(i)
end do
write(*,*)
return
end subroutine output
end program test_tree_map
実行結果は以下のとおりです。
keys: 1 2 3 4 5 6 7 8 9
vals: 10 20 30 40 50 60 70 80 90
key: 10 val: -1
first key: 1
last key: 9
lower key of 5: 4
floor key of 5: 5
ceiling key of 5: 5
higher key of 5: 6
key: 5 val: 500
size: 9
size: 0
key: 5 val: 50
key: 10 val: -1000000
keys: 1 2 3 4 5 6 7 8 9
vals: 10 20 30 40 50 60 70 80 90
lower key of 1: -1000000
floor key of 1: 1
ceiling key of 9: 9
higher key of 9: -1000000
これで Fortran でも連想配列を使うことができるようになりました。 map を実装すると、同時に set も実装できていることになるので、これで Fortran で set が必要になっても安心です。ただ、キーと値の型ごとに作り直さなければならないのが面倒ですね。
追記 2019/08/11
不要な部分を削り、少し短くなりました (ソースコード) 。