LoginSignup
14
13

More than 3 years have passed since last update.

Fortranで疎行列:オブジェクト指向Fortranで楽々プログラミング

Posted at

Fortranでコードを書いて数値計算をしようとする人がたまに遭遇するのは、疎行列ベクトル積です。
疎行列形式を使えば行列の非ゼロ要素だけを持っていられるので、メモリも少なくて済みますし、計算も速いです。
しかし、疎行列とベクトルの積をFortranで実装しようとする時、全国各地で車輪の再発明が行われているように思います(多分MKLとかのマニュアルを眺めながら試行錯誤したり)。
COO形式、CSR形式、CSC形式、などなど色々な格納形式があって、どれも一体どうやって使えばいいのか、などなど考えたりすることもあると思います。
特に、i,j成分の値をセットするとか更新するとか、頭を使います。
なお、c++だとBoostを使うと疎行列を簡単に扱うことができて、A(i,j)=vみたいな形で値を代入できます。
また、Juliaでも、A[i,j]=vみたいな形で値を代入できます。Pythonでも多分できます。
「研究のコードはFortranで書き始めてるし、今更別の言語に行くのも大変...」そんな方もいると思います。
というわけで、モダンなFortranで疎行列の扱いを簡単にする方法を記述します。
オブジェクト指向なFortranの紹介にもなっています。

使用する機能

  1. module
  2. type
  3. オブジェクト指向Fortran(classの導入)
  4. 演算子のオーバーロード(オプション。なくてもよい)
  5. MKL(なくてもできるがあったほうが速い)

目的

  1. バグの混入を減らし、コーディングに時間を取られすぎないようにする
  2. 拡張しても修正点が少なくて済むようなコーディングを行う

今時Fortranを使う人は数値計算が目的でやっていることが多いわけですから、本来やりたい数値計算があって、そのためにプログラムを書いているはずです。
しかしFortranの不便な部分をそのままにして昔ながらのコーディングをすると、やりたい数値計算を実行するためのハードルが上がってしまうわけで、それは勿体無いなと思います。また、初めて研究を開始して初めてプログラミングをする時に参考にするコードが指導教官や先輩のFortranコードだったりすると、それは昔ながらのFortran(場合によってはFORTRAN77)のスタイルを見て勉強することになってしまいます。また、先輩や指導教官に聞く時に知らない言語を使っているとアドバイスがもらいにくくなる側面もあり、その結果Fortranをやることになる、ということもありますよね。
今ならFortranも書きやすくできます。モダンなFortranを使えばこれまでの古いコードも参考にしつつ新しいわかりやすいコードが書けます。Fortranではあるわけですから、先輩にも聞きやすいですよね。
なお、先輩とか指導教官とかがコードを持っていないしそもそもプログラムを相談できる人が身近にいない場合、Fortranを使う理由はほとんどなく(スパコンで大規模並列計算する場合とかくらいですかね)、その場合はJuliaがオススメです。

バージョン

gcc version 9.2.0で動作確認

コード例

まず、コードの例を見てみましょう。

program main
    use CSRmodules
    implicit none
    type(CSRcomplex)::H
    integer::N
    complex(8)::v,alpha,beta
    integer::i,j
    real(8)::mu,t
    complex(8),allocatable::x(:)
    complex(8),allocatable::y(:)
    complex(8),allocatable::ytemp(:)
    complex(8),allocatable::z(:)

    N = 10
    v = -1d0
    t = -1d0
    mu = -1.5d0

    H = CSRcomplex(2*N)
    do i = 1,N
        j = i
        v = -mu

        call H%set(v,i,j)
        call H%set(-v,i+N,j+N)

        v = t
        j = i+1
        if (j > 0 .and. j < N+1) then
            call H%set(v,i,j)
            call H%set(v,j,i)

            call H%set(-v,i+N,j+N)
            call H%set(-v,j+N,i+N)
        end if

        j = i-1
        v =t

        if (j > 0 .and. j < N+1) then
            call H%set(v,i,j)
            call H%set(v,j,i)

            call H%set(-v,i+N,j+N)
            call H%set(-v,j+N,i+N)
        end if

        j = i+N
        v = 0.5d0
        call H%set(v,i,j)
        call H%set(v,j,i)

    end do

    call H%print()

    do i=1,N
        j = i+N
        v = 0.6d0
        call H%update(v,i,j)
        call H%update(v,j,i)

    end do
    call H%print()

    allocate(x(1:2*N))
    x = 0d0
    allocate(y(1:2*N))
    allocate(ytemp(1:2*N))
    y = 0d0
    ytemp = 0d0
    allocate(z(1:2*N))
    z(5) = 10d0
    x(3) = 1d0
    call H%matmul(x,y)
    write(*,*) "matmul(x,y) ", y

    ytemp = H*x
    write(*,*) "H*x", ytemp


    alpha = 2d0
    beta = 3d0
    call H%matmul2(x,y,z,alpha,beta)
    write(*,*) "y = alpha*A*x+beta*z",y


end program main

これは、疎行列をCSR形式で定義して、それに値を代入(H%set(v,i,j))しているコードです。また、H%matmul(x,y)は疎行列Hとベクトルyの積、H*xも疎行列Hとベクトルの積です。
値の更新はH%update(v,i,j)でやっています。
見ればわかりますように、CSR形式などでよく出てくる、valとかrowとかcolが全く顔を出していません。
通常の密行列のように計算ができていますね。
これを実現しているのが、CSRmodules モジュールです。
では、具体的に中身を見ていきましょう。

モジュールの中身

typeの定義

まず、typeを定義します。Fortranを数値計算で使う物理系の研究室だと、日本語で構造体と呼ばれるtypeはあまり使わずにコードが書かれている気がします。その理由(あるいは気のせい)は、FORTRAN77などでのプログラミングではあまり使われていなかったからかと思います。あるいは、構造体を使うと最適化がうまくいかない、とかでしょうか。あるいは、単純に知らない、場合もあると思います。
しかし、typeはとても便利です。subroutineに大量の引数があって、毎回毎回書くのがしんどかったりバグが出そうで怖かったりする方はtypeを使った方がよいでしょう。
以下のコードは全てmodule CSRmodulesの内部に定義されているとします。また最後に全体のコードも載せます。

    type CSRcomplex
        private
        integer::N !正方行列を仮定している。
        complex(8),allocatable::val(:)
        integer,allocatable::row(:)
        integer,allocatable::col(:)

        contains
        procedure::set => set_c !(i,j)成分にvを代入 set(v,i,j)
        procedure::print => print_csr !行列の中身をprint
        procedure::nonzero => get_nonzeronum  !非ゼロ要素の数を数える
        procedure::update => update_c !(i,j)成分の値をアップデート。なければエラーで止まる。
        procedure::matmul => matmul_axy  !y = A*x: matmul(x,y)
        procedure::matmul2 => matmul_axy2 !y = alpha*A*x +beta*b: matmul2(x,y,b,alpha,beta)
    end type

まず、CSRcomplextypeを定義しました。このtypeにはNvalrowcolの四つの変数が格納されています。取り出すにはH%valなどとパーセントをつけます。PythonやJuliaだとパーセントの代わりに.ですね。これで、引数をひとつだけ、typeひとつを定義しておけば、わざわざsubroutineでvalだのrowだのを使わなくて済みます。

残りのprocedureはオブジェクト指向の部分です。
Fortranで数値計算をしている人はオブジェクト指向恐怖症な人もいるかと思いますが、Fortranのオブジェクト指向は難しくありません。これはただの書き方の方便と思ってください。例えば、subroutineでmatmul(A,x,y)というものがあったとします。オブジェクト指向ではこれをA%matmul(x,y)とします。ただそれだけです。
もう少し気になる方は
「Fortranからみたオブジェクト指向:オブジェクト指向でFortranコードを書く」
https://qiita.com/cometscome_phys/items/f87080286c6fdf72e49a
を参考にしてみてください。
日本語で考えると、「matmulというsubroutineはAとxとyが引数」だったものが、「Aが持つmatmulというsubroutineはxとyが引数」となっています。Aがmatmulを所有しているので、Aというオブジェクトがmatmulという動作を持つと言い換えることができて、これがオブジェクト指向です。
ということで、「type CSRcomplexが持つ動作」をprocedureとして持つことができています。

初期化(コンストラクタ)

次に、typeを定義したら初期化したい、と思うと思います。これは、

    interface CSRcomplex
        module procedure::init_CSRcomplex
        module procedure::init_CSRcomplex_initialnum
    end interface CSRcomplex

でOKです。これは、引数がひとつの関数H=CSRcomplex(N)と二つの関数H=CSRcomplex(N,m)をひとつの名前CSRcomplexで呼ぶためのinterfaceです。
馴染みのあるものとしては、倍精度の単精度もsin関数はsin(x)で書けるものがありますね。このようなものを総称名と呼びます。昔のFORTRANだと、引数xが倍精度か単精度かなどでdsin(x)などを呼んでいました。

コンストラクタとも呼ばれるこの関数の中身は

    type(CSRcomplex) function init_CSRcomplex(N) result(A)!疎行列を初期化
        implicit none
        integer,intent(in)::N      
        A = init_CSRcomplex_initialnum(N,N) !配列の初期の長さをNとした。
        return
    end function

    type(CSRcomplex) function init_CSRcomplex_initialnum(N,initialnum) result(A)!疎行列を初期化。配列の初期の長さはinitialnumに設定。
        implicit none
        integer,intent(in)::N,initialnum

        A%N = N
        allocate(A%row(N+1))
        allocate(A%val(initialnum)) !長さinitialnumに初期化
        allocate(A%col(initialnum)) !長さinitialnumに初期化
        A%val = 0d0
        A%col = 0
        A%row = 1

    end function

です。疎行列CSRで必要なvalやrowやcolの配列を確保したり0で埋めたりしています。
ここで、先ほど定義したtypeにprivate属性がついていることに注意してください。privateがついているということは、このモジュールの外からはvalやrowやcolにアクセスできない、ということです。外のコードで不用意に内部をいじることができてしまうとバグの温床になりますので、必ずこのモジュール内で定義した関数やsubroutineを使います。

演算子のオーバーロード

subroutine全部を説明すると大変なので一部を抜粋します。
例えば、疎行列ベクトル積をH*xみたいにシンプルに書きたい場合には、掛け算「*」を定義すればよいです。これは、

    interface operator(*)
        module procedure mult
    end interface

として、関数multを定義すればOKです。これで、「*」を使うとmultが呼ばれます。

モジュール全体

モジュール全体を以下に示します。疎行列の要素を取り出すget系の関数は未実装です。しかし、要素に値を代入するsetは実装しました。これで、valとかrowとかcolとかを考えなくても、このmoduleさえデバッグが終われば、安心して他のコードでCSR形式の疎行列を使うことができます。
なお、疎行列関連の実装はJuliaのSparseArrays.jlパッケージのソースを参考にしました。これはCSC形式ですのでCSRに書き換えました。

module CSRmodules
    implicit none
    private

    public CSRcomplex
    public :: operator(*)

    type CSRcomplex
        private
        integer::N !正方行列を仮定している。
        complex(8),allocatable::val(:)
        integer,allocatable::row(:)
        integer,allocatable::col(:)

        contains
        procedure::set => set_c !(i,j)成分にvを代入 set(v,i,j)
        procedure::print => print_csr !行列の中身をprint
        procedure::nonzero => get_nonzeronum  !非ゼロ要素の数を数える
        procedure::update => update_c !(i,j)成分の値をアップデート。なければエラーで止まる。
        procedure::matmul => matmul_axy  !y = A*x: matmul(x,y)
        procedure::matmul2 => matmul_axy2 !y = alpha*A*x +beta*b: matmul2(x,y,b,alpha,beta)
    end type

    interface CSRcomplex
        module procedure::init_CSRcomplex
        module procedure::init_CSRcomplex_initialnum
    end interface CSRcomplex

    interface insert_element
        module procedure::insert_c
        module procedure::insert_int
    end interface

    interface operator(*)
        module procedure mult
    end interface


    contains 



    function mult(A,x) result(y)
        implicit none
        type(CSRcomplex),intent(in)::A
        complex(8),intent(in)::x(:)
        complex(8)::y(ubound(x,1))

        if (A%N .ne. ubound(x,1)) then
            write(*,*) "error in CSRmodules! size mismatch"
            stop
        end if
        y(1:ubound(x,1)) = 0d0
        call mkl_zcsrgemv("N", ubound(x,1), A%val, A%row, A%col, x, y)

    end function

    subroutine matmul_axy(self,x,y) 
        implicit none
        class(CSRcomplex),intent(in)::self
        complex(8),intent(in)::x(:)
        complex(8),intent(out)::y(*)
        !allocate(y(1:ubound(x,1)))
        if (self%N .ne. ubound(x,1)) then
            write(*,*) "error in CSRmodules! size mismatch"
            stop
        end if
        y(1:ubound(x,1)) = 0d0
        call mkl_zcsrgemv("N", ubound(x,1), self%val, self%row, self%col, x, y)

        return
    end subroutine

    subroutine matmul_axy2(self,x,y,b,alpha,beta) !y = alpha*A*x+ beta*b 
        implicit none
        class(CSRcomplex),intent(in)::self
        complex(8),intent(in)::x(:)
        complex(8),intent(in)::b(:)
        complex(8),intent(in)::alpha,beta
        complex(8),intent(out)::y(*)
        !allocate(y(1:ubound(x,1)))
        if (self%N .ne. ubound(x,1)) then
            write(*,*) "error in CSRmodules! size mismatch"
            stop
        end if
        y(1:ubound(x,1)) = 0d0
        call mkl_zcsrgemv("N", ubound(x,1), self%val, self%row, self%col, x, y)
        y(1:ubound(x,1)) = alpha*y(1:ubound(x,1)) + beta*b(1:ubound(x,1))

        return
    end subroutine    

    type(CSRcomplex) function init_CSRcomplex(N) result(A)!疎行列を初期化
        implicit none
        integer,intent(in)::N      
        A = init_CSRcomplex_initialnum(N,N) !配列の初期の長さをNとした。
        return
    end function

    type(CSRcomplex) function init_CSRcomplex_initialnum(N,initialnum) result(A)!疎行列を初期化。配列の初期の長さはinitialnumに設定。
        implicit none
        integer,intent(in)::N,initialnum

        A%N = N
        allocate(A%row(N+1))
        allocate(A%val(initialnum)) !長さinitialnumに初期化
        allocate(A%col(initialnum)) !長さinitialnumに初期化
        A%val = 0d0
        A%col = 0
        A%row = 1

    end function

    subroutine print_csr(self)
        implicit none
        class(CSRcomplex)::self
        integer::i,j,k


        do i=1,self%N
            do k=self%row(i),self%row(i+1) -1
                j = self%col(k)
                write(*,*) i,j,self%val(k)
            end do
        end do

        return
    end subroutine print_csr

    subroutine set_c(self,v,i,j)
        implicit none
        class(CSRcomplex)::self
        complex(8),intent(in)::v
        integer,intent(in)::i,j
        integer::rowifirstk,rowilastk
        integer::searchk
        integer::nonzeros
        integer::m

        call boundcheck(self%N,i,j)
        nonzeros = self%nonzero()
        rowifirstk = self%row(i)
        rowilastk = self%row(i+1)-1 
        call searchsortedfirst(self%col(1:nonzeros),j,rowifirstk,rowilastk,searchk)
        !write(*,*) "s",searchk,i,j,rowifirstk,rowilastk 

        if (searchk .le. rowilastk .and. self%col(searchk) .eq. j) then
            self%val(searchk) = v
            return
        end if

        if (abs(v) .ne. 0d0) then

            call insert_element(self%col,searchk,j,nonzeros)
            call insert_element(self%val,searchk,v,nonzeros)
            do m=i+1,self%N+1
                self%row(m) = self%row(m) + 1
            end do


        end if

        return
    end subroutine set_c

    subroutine insert_c(vec,pos,item,nz)
        implicit none
        complex(8),intent(inout),allocatable::vec(:)
        integer,intent(in)::pos,nz
        complex(8),intent(in)::item
        integer::vlength
        complex(8),allocatable::temp(:)

        vlength = ubound(vec,1)

        if (nz .ge. vlength) then !確保している配列valの長さが足りない時
            allocate(temp(vlength+1))
            if (pos > 1) then
                temp(1:pos-1)= vec(1:pos-1)
            end if
            temp(pos+1:nz+1) = vec(pos:nz)
            temp(pos) = item
            deallocate(vec)
            allocate(vec(vlength+vlength))
            vec(1:vlength+1) = temp(1:vlength+1)
            vec(vlength+2:vlength+vlength) = 0d0
        else
            vec(pos+1:nz+1) = vec(pos:nz)
            vec(pos) = item
        end if

        return
    end subroutine insert_c

    subroutine insert_int(vec,pos,item,nz)
        implicit none
        integer,intent(inout),allocatable::vec(:)
        integer,intent(in)::pos,item,nz
        integer::vlength
        integer,allocatable::temp(:)

        vlength = ubound(vec,1)

        if (nz .ge. vlength) then !確保している配列valの長さが足りない時
            allocate(temp(vlength+1))
            if (pos > 1) then
                temp(1:pos-1)= vec(1:pos-1)
            end if
            temp(pos+1:nz+1) = vec(pos:nz)
            temp(pos) = item
            deallocate(vec)
            allocate(vec(vlength+vlength))
            vec(1:vlength+1) = temp(1:vlength+1)
            vec(vlength+2:vlength+vlength) = 0
        else
            vec(pos+1:nz+1) = vec(pos:nz)
            vec(pos) = item
        end if

        return
    end subroutine insert_int

    subroutine update_c(self,v,i,j)
        implicit none
        class(CSRcomplex)::self
        complex(8),intent(in)::v
        integer,intent(in)::i,j
        integer::rowifirstk,rowilastk
        integer::searchk

        call boundcheck(self%N,i,j)
        rowifirstk = self%row(i)
        rowilastk = self%row(i+1)-1 
        call searchsortedfirst(self%col,j,rowifirstk,rowilastk,searchk)


        if (searchk .le. rowilastk .and. self%col(searchk) .eq. j) then
            self%val(searchk) = v
            return
        else
            write(*,*) "error! in CSRmodules. There is no entry at",i,j
            stop
        end if
        return
    end subroutine update_c

    integer function get_nonzeronum(self) result(nonzeros)
        implicit none
        class(CSRcomplex)::self
        nonzeros = self%row(self%N+1)
        return
    end function

    subroutine boundcheck(N,i,j) !配列外参照をチェック
        implicit none
        integer,intent(in)::N,i,j
        if (i < 1 .or. i > N) then
            write(*,*) "error! i should be [1:N]."
            stop
        end if
        if (j < 1 .or. j > N) then
            write(*,*) "error! j should be [1:N]."
            stop
        end if
        return
    end subroutine

    subroutine searchsortedfirst(vec,i,istart,iend,searchk) 
        implicit none
        integer,intent(in)::vec(:)
        integer,intent(in)::i,istart,iend
        integer,intent(out)::searchk
        integer::k
        if (iend - istart < 0) then
            searchk = istart
            return
        end if

        searchk = iend+1!ubound(vec,1)+1
        do k=istart,iend
            if (vec(k) .eq. i) then
                searchk = k
                return
            end if
        end do

        return
    end subroutine searchsortedfirst

end module CSRmodules

最初のprogram mainと組み合われば、テストできると思います。
MKLを使っていますので、コンパイルにはMKLが必要です。

14
13
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
14
13