Fortranでコードを書いて数値計算をしようとする人がたまに遭遇するのは、疎行列ベクトル積です。
疎行列形式を使えば行列の非ゼロ要素だけを持っていられるので、メモリも少なくて済みますし、計算も速いです。
しかし、疎行列とベクトルの積をFortranで実装しようとする時、全国各地で車輪の再発明が行われているように思います(多分MKLとかのマニュアルを眺めながら試行錯誤したり)。
COO形式、CSR形式、CSC形式、などなど色々な格納形式があって、どれも一体どうやって使えばいいのか、などなど考えたりすることもあると思います。
特に、i,j成分の値をセットするとか更新するとか、頭を使います。
なお、c++だとBoostを使うと疎行列を簡単に扱うことができて、A(i,j)=vみたいな形で値を代入できます。
また、Juliaでも、A[i,j]=vみたいな形で値を代入できます。Pythonでも多分できます。
「研究のコードはFortranで書き始めてるし、今更別の言語に行くのも大変...」そんな方もいると思います。
というわけで、モダンなFortranで疎行列の扱いを簡単にする方法を記述します。
オブジェクト指向なFortranの紹介にもなっています。
使用する機能
- module
- type
- オブジェクト指向Fortran(classの導入)
- 演算子のオーバーロード(オプション。なくてもよい)
- MKL(なくてもできるがあったほうが速い)
目的
- バグの混入を減らし、コーディングに時間を取られすぎないようにする
- 拡張しても修正点が少なくて済むようなコーディングを行う
今時Fortranを使う人は数値計算が目的でやっていることが多いわけですから、本来やりたい数値計算があって、そのためにプログラムを書いているはずです。
しかしFortranの不便な部分をそのままにして昔ながらのコーディングをすると、やりたい数値計算を実行するためのハードルが上がってしまうわけで、それは勿体無いなと思います。また、初めて研究を開始して初めてプログラミングをする時に参考にするコードが指導教官や先輩のFortranコードだったりすると、それは昔ながらのFortran(場合によってはFORTRAN77)のスタイルを見て勉強することになってしまいます。また、先輩や指導教官に聞く時に知らない言語を使っているとアドバイスがもらいにくくなる側面もあり、その結果Fortranをやることになる、ということもありますよね。
今ならFortranも書きやすくできます。モダンなFortranを使えばこれまでの古いコードも参考にしつつ新しいわかりやすいコードが書けます。Fortranではあるわけですから、先輩にも聞きやすいですよね。
なお、先輩とか指導教官とかがコードを持っていないしそもそもプログラムを相談できる人が身近にいない場合、Fortranを使う理由はほとんどなく(スパコンで大規模並列計算する場合とかくらいですかね)、その場合はJuliaがオススメです。
#バージョン
gcc version 9.2.0で動作確認
コード例
まず、コードの例を見てみましょう。
program main
use CSRmodules
implicit none
type(CSRcomplex)::H
integer::N
complex(8)::v,alpha,beta
integer::i,j
real(8)::mu,t
complex(8),allocatable::x(:)
complex(8),allocatable::y(:)
complex(8),allocatable::ytemp(:)
complex(8),allocatable::z(:)
N = 10
v = -1d0
t = -1d0
mu = -1.5d0
H = CSRcomplex(2*N)
do i = 1,N
j = i
v = -mu
call H%set(v,i,j)
call H%set(-v,i+N,j+N)
v = t
j = i+1
if (j > 0 .and. j < N+1) then
call H%set(v,i,j)
call H%set(v,j,i)
call H%set(-v,i+N,j+N)
call H%set(-v,j+N,i+N)
end if
j = i-1
v =t
if (j > 0 .and. j < N+1) then
call H%set(v,i,j)
call H%set(v,j,i)
call H%set(-v,i+N,j+N)
call H%set(-v,j+N,i+N)
end if
j = i+N
v = 0.5d0
call H%set(v,i,j)
call H%set(v,j,i)
end do
call H%print()
do i=1,N
j = i+N
v = 0.6d0
call H%update(v,i,j)
call H%update(v,j,i)
end do
call H%print()
allocate(x(1:2*N))
x = 0d0
allocate(y(1:2*N))
allocate(ytemp(1:2*N))
y = 0d0
ytemp = 0d0
allocate(z(1:2*N))
z(5) = 10d0
x(3) = 1d0
call H%matmul(x,y)
write(*,*) "matmul(x,y) ", y
ytemp = H*x
write(*,*) "H*x", ytemp
alpha = 2d0
beta = 3d0
call H%matmul2(x,y,z,alpha,beta)
write(*,*) "y = alpha*A*x+beta*z",y
end program main
これは、疎行列をCSR形式で定義して、それに値を代入(H%set(v,i,j))しているコードです。また、H%matmul(x,y)
は疎行列Hとベクトルyの積、H*x
も疎行列Hとベクトルの積です。
値の更新はH%update(v,i,j)
でやっています。
見ればわかりますように、CSR形式などでよく出てくる、valとかrowとかcolが全く顔を出していません。
通常の密行列のように計算ができていますね。
これを実現しているのが、CSRmodules モジュールです。
では、具体的に中身を見ていきましょう。
モジュールの中身
typeの定義
まず、typeを定義します。Fortranを数値計算で使う物理系の研究室だと、日本語で構造体と呼ばれるtypeはあまり使わずにコードが書かれている気がします。その理由(あるいは気のせい)は、FORTRAN77などでのプログラミングではあまり使われていなかったからかと思います。あるいは、構造体を使うと最適化がうまくいかない、とかでしょうか。あるいは、単純に知らない、場合もあると思います。
しかし、typeはとても便利です。subroutineに大量の引数があって、毎回毎回書くのがしんどかったりバグが出そうで怖かったりする方はtypeを使った方がよいでしょう。
以下のコードは全てmodule CSRmodules
の内部に定義されているとします。また最後に全体のコードも載せます。
type CSRcomplex
private
integer::N !正方行列を仮定している。
complex(8),allocatable::val(:)
integer,allocatable::row(:)
integer,allocatable::col(:)
contains
procedure::set => set_c !(i,j)成分にvを代入 set(v,i,j)
procedure::print => print_csr !行列の中身をprint
procedure::nonzero => get_nonzeronum !非ゼロ要素の数を数える
procedure::update => update_c !(i,j)成分の値をアップデート。なければエラーで止まる。
procedure::matmul => matmul_axy !y = A*x: matmul(x,y)
procedure::matmul2 => matmul_axy2 !y = alpha*A*x +beta*b: matmul2(x,y,b,alpha,beta)
end type
まず、CSRcomplex
typeを定義しました。このtypeにはN
、val
、row
、col
の四つの変数が格納されています。取り出すにはH%val
などとパーセントをつけます。PythonやJuliaだとパーセントの代わりに.ですね。これで、引数をひとつだけ、typeひとつを定義しておけば、わざわざsubroutineでvalだのrowだのを使わなくて済みます。
残りのprocedure
はオブジェクト指向の部分です。
Fortranで数値計算をしている人はオブジェクト指向恐怖症な人もいるかと思いますが、Fortranのオブジェクト指向は難しくありません。これはただの書き方の方便と思ってください。例えば、subroutineでmatmul(A,x,y)
というものがあったとします。オブジェクト指向ではこれをA%matmul(x,y)
とします。ただそれだけです。
もう少し気になる方は
「Fortranからみたオブジェクト指向:オブジェクト指向でFortranコードを書く」
https://qiita.com/cometscome_phys/items/f87080286c6fdf72e49a
を参考にしてみてください。
日本語で考えると、「matmulというsubroutineはAとxとyが引数」だったものが、「Aが持つmatmulというsubroutineはxとyが引数」となっています。Aがmatmulを所有しているので、Aというオブジェクトがmatmulという動作を持つと言い換えることができて、これがオブジェクト指向です。
ということで、「type CSRcomplex
が持つ動作」をprocedure
として持つことができています。
初期化(コンストラクタ)
次に、typeを定義したら初期化したい、と思うと思います。これは、
interface CSRcomplex
module procedure::init_CSRcomplex
module procedure::init_CSRcomplex_initialnum
end interface CSRcomplex
でOKです。これは、引数がひとつの関数H=CSRcomplex(N)
と二つの関数H=CSRcomplex(N,m)
をひとつの名前CSRcomplex
で呼ぶためのinterfaceです。
馴染みのあるものとしては、倍精度の単精度もsin関数はsin(x)
で書けるものがありますね。このようなものを総称名と呼びます。昔のFORTRANだと、引数xが倍精度か単精度かなどでdsin(x)
などを呼んでいました。
コンストラクタとも呼ばれるこの関数の中身は
type(CSRcomplex) function init_CSRcomplex(N) result(A)!疎行列を初期化
implicit none
integer,intent(in)::N
A = init_CSRcomplex_initialnum(N,N) !配列の初期の長さをNとした。
return
end function
type(CSRcomplex) function init_CSRcomplex_initialnum(N,initialnum) result(A)!疎行列を初期化。配列の初期の長さはinitialnumに設定。
implicit none
integer,intent(in)::N,initialnum
A%N = N
allocate(A%row(N+1))
allocate(A%val(initialnum)) !長さinitialnumに初期化
allocate(A%col(initialnum)) !長さinitialnumに初期化
A%val = 0d0
A%col = 0
A%row = 1
end function
です。疎行列CSRで必要なvalやrowやcolの配列を確保したり0で埋めたりしています。
ここで、先ほど定義したtypeにprivate属性がついていることに注意してください。privateがついているということは、このモジュールの外からはvalやrowやcolにアクセスできない、ということです。外のコードで不用意に内部をいじることができてしまうとバグの温床になりますので、必ずこのモジュール内で定義した関数やsubroutineを使います。
##演算子のオーバーロード
subroutine全部を説明すると大変なので一部を抜粋します。
例えば、疎行列ベクトル積をH*x
みたいにシンプルに書きたい場合には、掛け算「*」を定義すればよいです。これは、
interface operator(*)
module procedure mult
end interface
として、関数mult
を定義すればOKです。これで、「*」を使うとmultが呼ばれます。
##モジュール全体
モジュール全体を以下に示します。疎行列の要素を取り出すget系の関数は未実装です。しかし、要素に値を代入するsetは実装しました。これで、valとかrowとかcolとかを考えなくても、このmoduleさえデバッグが終われば、安心して他のコードでCSR形式の疎行列を使うことができます。
なお、疎行列関連の実装はJuliaのSparseArrays.jlパッケージのソースを参考にしました。これはCSC形式ですのでCSRに書き換えました。
module CSRmodules
implicit none
private
public CSRcomplex
public :: operator(*)
type CSRcomplex
private
integer::N !正方行列を仮定している。
complex(8),allocatable::val(:)
integer,allocatable::row(:)
integer,allocatable::col(:)
contains
procedure::set => set_c !(i,j)成分にvを代入 set(v,i,j)
procedure::print => print_csr !行列の中身をprint
procedure::nonzero => get_nonzeronum !非ゼロ要素の数を数える
procedure::update => update_c !(i,j)成分の値をアップデート。なければエラーで止まる。
procedure::matmul => matmul_axy !y = A*x: matmul(x,y)
procedure::matmul2 => matmul_axy2 !y = alpha*A*x +beta*b: matmul2(x,y,b,alpha,beta)
end type
interface CSRcomplex
module procedure::init_CSRcomplex
module procedure::init_CSRcomplex_initialnum
end interface CSRcomplex
interface insert_element
module procedure::insert_c
module procedure::insert_int
end interface
interface operator(*)
module procedure mult
end interface
contains
function mult(A,x) result(y)
implicit none
type(CSRcomplex),intent(in)::A
complex(8),intent(in)::x(:)
complex(8)::y(ubound(x,1))
if (A%N .ne. ubound(x,1)) then
write(*,*) "error in CSRmodules! size mismatch"
stop
end if
y(1:ubound(x,1)) = 0d0
call mkl_zcsrgemv("N", ubound(x,1), A%val, A%row, A%col, x, y)
end function
subroutine matmul_axy(self,x,y)
implicit none
class(CSRcomplex),intent(in)::self
complex(8),intent(in)::x(:)
complex(8),intent(out)::y(*)
!allocate(y(1:ubound(x,1)))
if (self%N .ne. ubound(x,1)) then
write(*,*) "error in CSRmodules! size mismatch"
stop
end if
y(1:ubound(x,1)) = 0d0
call mkl_zcsrgemv("N", ubound(x,1), self%val, self%row, self%col, x, y)
return
end subroutine
subroutine matmul_axy2(self,x,y,b,alpha,beta) !y = alpha*A*x+ beta*b
implicit none
class(CSRcomplex),intent(in)::self
complex(8),intent(in)::x(:)
complex(8),intent(in)::b(:)
complex(8),intent(in)::alpha,beta
complex(8),intent(out)::y(*)
!allocate(y(1:ubound(x,1)))
if (self%N .ne. ubound(x,1)) then
write(*,*) "error in CSRmodules! size mismatch"
stop
end if
y(1:ubound(x,1)) = 0d0
call mkl_zcsrgemv("N", ubound(x,1), self%val, self%row, self%col, x, y)
y(1:ubound(x,1)) = alpha*y(1:ubound(x,1)) + beta*b(1:ubound(x,1))
return
end subroutine
type(CSRcomplex) function init_CSRcomplex(N) result(A)!疎行列を初期化
implicit none
integer,intent(in)::N
A = init_CSRcomplex_initialnum(N,N) !配列の初期の長さをNとした。
return
end function
type(CSRcomplex) function init_CSRcomplex_initialnum(N,initialnum) result(A)!疎行列を初期化。配列の初期の長さはinitialnumに設定。
implicit none
integer,intent(in)::N,initialnum
A%N = N
allocate(A%row(N+1))
allocate(A%val(initialnum)) !長さinitialnumに初期化
allocate(A%col(initialnum)) !長さinitialnumに初期化
A%val = 0d0
A%col = 0
A%row = 1
end function
subroutine print_csr(self)
implicit none
class(CSRcomplex)::self
integer::i,j,k
do i=1,self%N
do k=self%row(i),self%row(i+1) -1
j = self%col(k)
write(*,*) i,j,self%val(k)
end do
end do
return
end subroutine print_csr
subroutine set_c(self,v,i,j)
implicit none
class(CSRcomplex)::self
complex(8),intent(in)::v
integer,intent(in)::i,j
integer::rowifirstk,rowilastk
integer::searchk
integer::nonzeros
integer::m
call boundcheck(self%N,i,j)
nonzeros = self%nonzero()
rowifirstk = self%row(i)
rowilastk = self%row(i+1)-1
call searchsortedfirst(self%col(1:nonzeros),j,rowifirstk,rowilastk,searchk)
!write(*,*) "s",searchk,i,j,rowifirstk,rowilastk
if (searchk .le. rowilastk .and. self%col(searchk) .eq. j) then
self%val(searchk) = v
return
end if
if (abs(v) .ne. 0d0) then
call insert_element(self%col,searchk,j,nonzeros)
call insert_element(self%val,searchk,v,nonzeros)
do m=i+1,self%N+1
self%row(m) = self%row(m) + 1
end do
end if
return
end subroutine set_c
subroutine insert_c(vec,pos,item,nz)
implicit none
complex(8),intent(inout),allocatable::vec(:)
integer,intent(in)::pos,nz
complex(8),intent(in)::item
integer::vlength
complex(8),allocatable::temp(:)
vlength = ubound(vec,1)
if (nz .ge. vlength) then !確保している配列valの長さが足りない時
allocate(temp(vlength+1))
if (pos > 1) then
temp(1:pos-1)= vec(1:pos-1)
end if
temp(pos+1:nz+1) = vec(pos:nz)
temp(pos) = item
deallocate(vec)
allocate(vec(vlength+vlength))
vec(1:vlength+1) = temp(1:vlength+1)
vec(vlength+2:vlength+vlength) = 0d0
else
vec(pos+1:nz+1) = vec(pos:nz)
vec(pos) = item
end if
return
end subroutine insert_c
subroutine insert_int(vec,pos,item,nz)
implicit none
integer,intent(inout),allocatable::vec(:)
integer,intent(in)::pos,item,nz
integer::vlength
integer,allocatable::temp(:)
vlength = ubound(vec,1)
if (nz .ge. vlength) then !確保している配列valの長さが足りない時
allocate(temp(vlength+1))
if (pos > 1) then
temp(1:pos-1)= vec(1:pos-1)
end if
temp(pos+1:nz+1) = vec(pos:nz)
temp(pos) = item
deallocate(vec)
allocate(vec(vlength+vlength))
vec(1:vlength+1) = temp(1:vlength+1)
vec(vlength+2:vlength+vlength) = 0
else
vec(pos+1:nz+1) = vec(pos:nz)
vec(pos) = item
end if
return
end subroutine insert_int
subroutine update_c(self,v,i,j)
implicit none
class(CSRcomplex)::self
complex(8),intent(in)::v
integer,intent(in)::i,j
integer::rowifirstk,rowilastk
integer::searchk
call boundcheck(self%N,i,j)
rowifirstk = self%row(i)
rowilastk = self%row(i+1)-1
call searchsortedfirst(self%col,j,rowifirstk,rowilastk,searchk)
if (searchk .le. rowilastk .and. self%col(searchk) .eq. j) then
self%val(searchk) = v
return
else
write(*,*) "error! in CSRmodules. There is no entry at",i,j
stop
end if
return
end subroutine update_c
integer function get_nonzeronum(self) result(nonzeros)
implicit none
class(CSRcomplex)::self
nonzeros = self%row(self%N+1)
return
end function
subroutine boundcheck(N,i,j) !配列外参照をチェック
implicit none
integer,intent(in)::N,i,j
if (i < 1 .or. i > N) then
write(*,*) "error! i should be [1:N]."
stop
end if
if (j < 1 .or. j > N) then
write(*,*) "error! j should be [1:N]."
stop
end if
return
end subroutine
subroutine searchsortedfirst(vec,i,istart,iend,searchk)
implicit none
integer,intent(in)::vec(:)
integer,intent(in)::i,istart,iend
integer,intent(out)::searchk
integer::k
if (iend - istart < 0) then
searchk = istart
return
end if
searchk = iend+1!ubound(vec,1)+1
do k=istart,iend
if (vec(k) .eq. i) then
searchk = k
return
end if
end do
return
end subroutine searchsortedfirst
end module CSRmodules
最初のprogram mainと組み合われば、テストできると思います。
MKLを使っていますので、コンパイルにはMKLが必要です。