LoginSignup
5
2

More than 1 year has passed since last update.

Fortranで機種依存性をなんとかする:cmakeとプリプロセッサの使用

Last updated at Posted at 2022-12-07

Fortranでコードを書くときに面倒だなと思うことの一つに、機種依存性です。例えばMKLが入っている場合とそうでない場合などでコードを書き直すのは面倒なわけです。あとはgfortranなのかintel fortranなのかなどのコンパイルの種類や、コンパイルオプションなども機種依存しています。これを解決するための方法の一つが以前の記事です。この記事の件をさらに進めて、なるべく楽にコンパイルする方法について模索しました。

今回は、以前のFortranで疎行列:オブジェクト指向Fortranで楽々プログラミングの記事を改良します。この記事のコードはMKLがインストールされていないと動かなかったのですが、MKLが入っていてもいなくてもコンパイルできるように、プリプロセッサとcmakeを使います。

プリプロセッサ

プリプロセッサというものは、古の時代(いつかは未定義)ではよく使われているように思いますが、最近の記事だとほとんど見ないものかなと思います。しかし、使われているところには使われていまして、新しくFortranを書く人は知らないまま、ということは多いような気がします。このプリプロセッサというものは、ソースコードに

#ifdef MKL
        call mkl_zcsrgemv("N", ubound(x,1), A%val, A%row, A%col, x, y)
#else
        call matrix_vector_c(ubound(x,1), A%val, A%row, A%col, x, y)
#endif

のようにifdefをつけておくと、コンパイル時に

gfortran -DMKL

などとすると上が呼ばれ、DMKLをつけないと下が呼ばれる、というような分岐を行うことができます。これによって、例えば、MKLが入っているかどうかでソースコードを変更することができるわけです。

一つのはまりポイントとしては、gfortranを使う場合、ソースコードは.f90ではなく.F90とfを大文字にしないとプリプロセッサが使えない、ということです。とりあえずコードは.F90としておきましょう。

cmake

上のプリプロセッサをcmakeで使いたい場合には、

target_compile_definitions(csrmodule PUBLIC MKL )

のようにすると、コンパイル時に-DMKLをつけてくれます。

コンパイルするコード

今回は、疎行列をCSR形式で扱うmoduleを作り、そのmoduleをライブラリ化して他のコードから呼べるようにします。
なお、コードはこちらにも一式置いておきました。

まず、mainのコードを

program main
    use csrmodule
    implicit none
    type(CSRcomplex)::H
    integer::N
    complex(8)::v,alpha,beta
    integer::i,j
    real(8)::mu,t
    complex(8),allocatable::x(:)
    complex(8),allocatable::y(:)
    complex(8),allocatable::ytemp(:)
    complex(8),allocatable::z(:)
    complex(8),allocatable::Hdense(:,:)
    complex(8),allocatable::ytemp2(:)

    N = 10
    v = -1d0
    t = -1d0
    mu = -1.5d0
    allocate(Hdense(2*N,2*N))

    H = CSRcomplex(2*N)
    do i = 1,N
        j = i
        v = -mu

        call H%set(v,i,j)
        Hdense(i,j) = v
        call H%set(-v,i+N,j+N)
        Hdense(i+N,j+N) = -v

        v = t
        j = i+1
        if (j > 0 .and. j < N+1) then
            call H%set(v,i,j)
            Hdense(i,j) = v
            call H%set(v,j,i)
            Hdense(j,i) = v

            call H%set(-v,i+N,j+N)
            Hdense(i+N,j+N) = -v
            call H%set(-v,j+N,i+N)
            Hdense(j+N,i+N) = -v
        end if

        j = i-1
        v =t

        if (j > 0 .and. j < N+1) then
            call H%set(v,i,j)
            Hdense(i,j) = v
            call H%set(v,j,i)
            Hdense(j,i) = v

            call H%set(-v,i+N,j+N)
            Hdense(i+N,j+N) = -v
            call H%set(-v,j+N,i+N)
            Hdense(j+N,i+N) = -v
        end if

        j = i+N
        v = 0.5d0
        call H%set(v,i,j)
        Hdense(i,j) = v
        call H%set(v,j,i)
        Hdense(j,i) = v

    end do

    call H%print()

    do i=1,N
        j = i+N
        v = 0.6d0
        call H%update(v,i,j)
        Hdense(i,j) = v
        call H%update(v,j,i)
        Hdense(j,i) = v


    end do
    call H%print()

    allocate(x(1:2*N))
    x = 0d0
    allocate(y(1:2*N))
    allocate(ytemp(1:2*N))
    allocate(ytemp2(1:2*N))
    y = 0d0
    ytemp = 0d0
    allocate(z(1:2*N))
    z(5) = 10d0
    x(3) = 1d0
    call H%matmul(x,y)
    !write(*,*) "matmul(x,y) ", y


    ytemp = H*x
    write(*,*) "H*x", ytemp

    ytemp2 = matmul(Hdense,x)
    do i=1,2*N
        write(*,*) i,ytemp(i),ytemp2(i)
    end do
    write(*,*) "diff = ",dot_product(ytemp-ytemp2,ytemp-ytemp2)/N




    alpha = 2d0
    beta = 3d0
    call H%matmul2(x,y,z,alpha,beta)
    !write(*,*) "y = alpha*A*x+beta*z",y


end program main

とします。moduleとしてcsrmoduleを呼んでいます。
そして、csrmodule

module csrmodule
    implicit none
    private

    public CSRcomplex
    public :: operator(*)

    type CSRcomplex
        private
        integer::N !正方行列を仮定している。
        complex(8),allocatable::val(:)
        integer,allocatable::row(:)
        integer,allocatable::col(:)

        contains
        procedure::set => set_c !(i,j)成分にvを代入 set(v,i,j)
        procedure::print => print_csr !行列の中身をprint
        procedure::nonzero => get_nonzeronum  !非ゼロ要素の数を数える
        procedure::update => update_c !(i,j)成分の値をアップデート。なければエラーで止まる。
        procedure::matmul => matmul_axy  !y = A*x: matmul(x,y)
        procedure::matmul2 => matmul_axy2 !y = alpha*A*x +beta*b: matmul2(x,y,b,alpha,beta)
    end type

    interface CSRcomplex
        module procedure::init_CSRcomplex
        module procedure::init_CSRcomplex_initialnum
    end interface CSRcomplex

    interface insert_element
        module procedure::insert_c
        module procedure::insert_int
    end interface

    interface operator(*)
        module procedure mult
    end interface


    contains 



    type(CSRcomplex) function init_CSRcomplex(N) result(A)!疎行列を初期化
        implicit none
        integer,intent(in)::N      
        A = init_CSRcomplex_initialnum(N,N) !配列の初期の長さをNとした。
        return
    end function

    type(CSRcomplex) function init_CSRcomplex_initialnum(N,initialnum) result(A)!疎行列を初期化。配列の初期の長さはinitialnumに設定。
        implicit none
        integer,intent(in)::N,initialnum

        A%N = N
        allocate(A%row(N+1))
        allocate(A%val(initialnum)) !長さinitialnumに初期化
        allocate(A%col(initialnum)) !長さinitialnumに初期化
        A%val = 0d0
        A%col = 0
        A%row = 1
    end function


    subroutine matrix_vector_c(N,val,row,col,x,y)
        implicit none
        integer,intent(in)::N
        complex(8),intent(in)::val(:)
        integer,intent(in)::row(:)
        integer,intent(in)::col(:)
        complex(8),intent(in)::x(N)
        complex(8),intent(out)::y(N)
        integer::i,j
        do i=1,N
            y(i) = 0.0d0
            do j=row(i), row(i+1)-1
                y(i) = y(i)+val(j)*x(col(j))
            end do
        end do
    end subroutine


    function mult(A,x) result(y)
        implicit none
        type(CSRcomplex),intent(in)::A
        complex(8),intent(in)::x(:)
        complex(8)::y(ubound(x,1))

        if (A%N .ne. ubound(x,1)) then
            write(*,*) "error in CSRmodules! size mismatch"
            stop
        end if
        y(1:ubound(x,1)) = 0d0

#ifdef MKL
        call mkl_zcsrgemv("N", ubound(x,1), A%val, A%row, A%col, x, y)
#else
        call matrix_vector_c(ubound(x,1), A%val, A%row, A%col, x, y)
#endif

    end function

    subroutine matmul_axy(self,x,y) 
        implicit none
        class(CSRcomplex),intent(in)::self
        complex(8),intent(in)::x(:)
        complex(8),intent(out)::y(*)
        !allocate(y(1:ubound(x,1)))
        if (self%N .ne. ubound(x,1)) then
            write(*,*) "error in CSRmodules! size mismatch"
            stop
        end if
        y(1:ubound(x,1)) = 0d0

#ifdef MKL
        call mkl_zcsrgemv("N", ubound(x,1), self%val, self%row, self%col, x, y)
#else
        call matrix_vector_c(ubound(x,1), self%val, self%row, self%col, x, y)
#endif


        return
    end subroutine

    subroutine matmul_axy2(self,x,y,b,alpha,beta) !y = alpha*A*x+ beta*b 
        implicit none
        class(CSRcomplex),intent(in)::self
        complex(8),intent(in)::x(:)
        complex(8),intent(in)::b(:)
        complex(8),intent(in)::alpha,beta
        complex(8),intent(out)::y(*)
        !allocate(y(1:ubound(x,1)))
        if (self%N .ne. ubound(x,1)) then
            write(*,*) "error in CSRmodules! size mismatch"
            stop
        end if
        y(1:ubound(x,1)) = 0d0
#ifdef MKL
        call mkl_zcsrgemv("N", ubound(x,1), self%val, self%row, self%col, x, y)
#else
        call matrix_vector_c(ubound(x,1), self%val, self%row, self%col, x, y)
#endif

        y(1:ubound(x,1)) = alpha*y(1:ubound(x,1)) + beta*b(1:ubound(x,1))

        return
    end subroutine    


    subroutine print_csr(self)
        implicit none
        class(CSRcomplex)::self
        integer::i,j,k


        do i=1,self%N
            do k=self%row(i),self%row(i+1) -1
                j = self%col(k)
                write(*,*) i,j,self%val(k)
            end do
        end do

        return
    end subroutine print_csr

    subroutine set_c(self,v,i,j)
        implicit none
        class(CSRcomplex)::self
        complex(8),intent(in)::v
        integer,intent(in)::i,j
        integer::rowifirstk,rowilastk
        integer::searchk
        integer::nonzeros
        integer::m

        call boundcheck(self%N,i,j)
        nonzeros = self%nonzero()
        rowifirstk = self%row(i)
        rowilastk = self%row(i+1)-1 
        call searchsortedfirst(self%col(1:nonzeros),j,rowifirstk,rowilastk,searchk)
        !write(*,*) "s",searchk,i,j,rowifirstk,rowilastk 

        if (searchk .le. rowilastk .and. self%col(searchk) .eq. j) then
            self%val(searchk) = v
            return
        end if

        if (abs(v) .ne. 0d0) then

            call insert_element(self%col,searchk,j,nonzeros)
            call insert_element(self%val,searchk,v,nonzeros)
            do m=i+1,self%N+1
                self%row(m) = self%row(m) + 1
            end do


        end if

        return
    end subroutine set_c

    subroutine insert_c(vec,pos,item,nz)
        implicit none
        complex(8),intent(inout),allocatable::vec(:)
        integer,intent(in)::pos,nz
        complex(8),intent(in)::item
        integer::vlength
        complex(8),allocatable::temp(:)

        vlength = ubound(vec,1)

        if (nz .ge. vlength) then !確保している配列valの長さが足りない時
            allocate(temp(vlength+1))
            if (pos > 1) then
                temp(1:pos-1)= vec(1:pos-1)
            end if
            temp(pos+1:nz+1) = vec(pos:nz)
            temp(pos) = item
            deallocate(vec)
            allocate(vec(vlength+vlength))
            vec(1:vlength+1) = temp(1:vlength+1)
            vec(vlength+2:vlength+vlength) = 0d0
        else
            vec(pos+1:nz+1) = vec(pos:nz)
            vec(pos) = item
        end if

        return
    end subroutine insert_c

    subroutine insert_int(vec,pos,item,nz)
        implicit none
        integer,intent(inout),allocatable::vec(:)
        integer,intent(in)::pos,item,nz
        integer::vlength
        integer,allocatable::temp(:)

        vlength = ubound(vec,1)

        if (nz .ge. vlength) then !確保している配列valの長さが足りない時
            allocate(temp(vlength+1))
            if (pos > 1) then
                temp(1:pos-1)= vec(1:pos-1)
            end if
            temp(pos+1:nz+1) = vec(pos:nz)
            temp(pos) = item
            deallocate(vec)
            allocate(vec(vlength+vlength))
            vec(1:vlength+1) = temp(1:vlength+1)
            vec(vlength+2:vlength+vlength) = 0
        else
            vec(pos+1:nz+1) = vec(pos:nz)
            vec(pos) = item
        end if

        return
    end subroutine insert_int

    subroutine update_c(self,v,i,j)
        implicit none
        class(CSRcomplex)::self
        complex(8),intent(in)::v
        integer,intent(in)::i,j
        integer::rowifirstk,rowilastk
        integer::searchk

        call boundcheck(self%N,i,j)
        rowifirstk = self%row(i)
        rowilastk = self%row(i+1)-1 
        call searchsortedfirst(self%col,j,rowifirstk,rowilastk,searchk)


        if (searchk .le. rowilastk .and. self%col(searchk) .eq. j) then
            self%val(searchk) = v
            return
        else
            write(*,*) "error! in CSRmodules. There is no entry at",i,j
            stop
        end if
        return
    end subroutine update_c

    integer function get_nonzeronum(self) result(nonzeros)
        implicit none
        class(CSRcomplex)::self
        nonzeros = self%row(self%N+1)
        return
    end function

    subroutine boundcheck(N,i,j) !配列外参照をチェック
        implicit none
        integer,intent(in)::N,i,j
        if (i < 1 .or. i > N) then
            write(*,*) "error! i should be [1:N]."
            stop
        end if
        if (j < 1 .or. j > N) then
            write(*,*) "error! j should be [1:N]."
            stop
        end if
        return
    end subroutine

    subroutine searchsortedfirst(vec,i,istart,iend,searchk) 
        implicit none
        integer,intent(in)::vec(:)
        integer,intent(in)::i,istart,iend
        integer,intent(out)::searchk
        integer::k
        if (iend - istart < 0) then
            searchk = istart
            return
        end if

        searchk = iend+1!ubound(vec,1)+1
        do k=istart,iend
            if (vec(k) .eq. i) then
                searchk = k
                return
            end if
        end do

        return
    end subroutine searchsortedfirst

end module 

です。ifdefを使うことによって、MKLの有無で行列ベクトル積のサブルーチンを変更しています。

これをコンパイルするcmakeファイルが、

CMakeList.txt
cmake_minimum_required(VERSION 3.15)
project(csr Fortran)

set (CMAKE_Fortran_MODULE_DIRECTORY ${CMAKE_BINARY_DIR}/modules)
add_library(csrmodule STATIC )
set(srcdir src)

option(MKLUSE "Use MKL or not" OFF)
if(MKLUSE)
set(MKL_INTERFACE lp64)
set(MKL_INTERFACE_LAYER "_lp64")
find_package(MKL REQUIRED)
set(lapacklink MKL::MKL)
target_compile_definitions(csrmodule PUBLIC MKL )
else()
endif()

set(linklibraries ${lapacklink})


target_sources(csrmodule 
PRIVATE
${srcdir}/csrmodule.F90
)

target_link_libraries(csrmodule PRIVATE ${linklibraries} )


add_executable(test)

target_sources(test 
PRIVATE
${srcdir}/main.f90
)

target_link_libraries(test csrmodule )

これです。

target_link_libraries(test csrmodule )

の部分で、testというコードにライブラリcsrmoduleをリンクしています。この部分をご自分のコードに変更すれば、簡単にcsrmoduleを使うことができるわけです。

5
2
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
5
2