LoginSignup
2
0

More than 3 years have passed since last update.

mpi並列化楽々化計画 in Fortran part2: 何も考えずにMPI並列化したい!

Posted at

はじめに

Fortranを使う人に簡単にMPI並列化を使用するための方法をお伝えします。
私自身はFortranを長いこと使ってきまして、特に不便も感じずに過ごしてきていましたが、他の言語(Julia)を知ると「Fortranだとめんどい!」という場面が生じるようになってきました。というわけで、FortranでDoループを簡単に並列化する方法について述べます。特に、オブジェクト指向Fortranを使うことで非常に簡単にMPI並列ができることをみます。

なお、mpi並列化楽々化計画 in Fortran: mpi_allgathervを使いこなす
の続編です。

コード

まず、これをみてください。

test.f90
program main
    use mpicalls
    implicit none
    integer::N
    real(8),allocatable::vals(:)
    integer::ierr
    type(gatherv)::g

    N = 12
    call mpi_init(ierr)
    allocate(vals(N))

    g = gatherv(N)

    vals= g%pmap_dble(test)

    if (g%get_myrank() == 0) then
        write(*,*) vals
    end if

    contains 

    function test(i) result(v)
        implicit none
        real(8)::v
        integer::i
        v = i + 1
        return
    end function 
end program main

このコード、vals= g%pmap_dble(test)だけで長さNのループを並列化して実行してその結果をvalsに格納してくれます。Juliaでのpmapみたいなものを作ってみました。このコードの関数test(i)はループ番号iを引数にした関数です。これだけではこのような引数が1個の場合にしか使えない関数に見えるかもしれませんが、そんなことはなく、

test2.f90
subroutine test2(i,a,b,v)
    implicit none
    integer,intent(in)::i
    real(8),intent(in)::a,b
    real(8),intent(out)::v
    v = i*a + b
    return
end subroutine

program main
    use mpicalls
    implicit none
    integer::N
    real(8),allocatable::vals(:)
    integer::ierr
    type(gatherv)::g
    real(8)::a,b
    external test2

    N = 12
    call mpi_init(ierr)
    allocate(vals(N))

    g = gatherv(N)

    a = 100d0
    b = 1d0

    vals= g%pmap_dble(wrap)

    if (g%get_myrank() == 0) then
        write(*,*) vals
    end if


    contains 

    function wrap(i) result(v)
        implicit none
        real(8)::v
        integer::i
        !write(*,*) i
        call test2(i,a,b,v)
        return
    end function 

    function test(i) result(v)
        implicit none
        real(8)::v
        integer::i
        v = i + 1
        return
    end function 
end program main

とwrapという関数を使えば、任意の変数の数のfunctionやsubroutineを突っ込むことができます。

つまり、このコードを使えば任意のfunctionやsubroutineをDoループ並列化できるわけです。

モジュール

ポイントはmpicallsというモジュールです。この中身は、

module.jl
module mpicalls
    implicit none
    include "mpif.h"
    private
    public::gatherv!,init_gatherv

    type gatherv
        integer::nprocs
        integer::myrank
        integer::N
        integer,allocatable::rcounts(:)
        integer,allocatable::displs(:)
        integer::count

        contains
        procedure:: allgatherv_dble =>  allgatherv_dble
        procedure:: allgatherv_complex =>  allgatherv_complex
        procedure:: get_start
        procedure:: get_end
        procedure:: pmap_dble
        procedure::get_myrank
    end type gatherv

    interface gatherv
        module procedure::init_gatherv
    end interface gatherv

    contains

    integer function get_myrank(self) result(myrank)
        implicit none
        class(gatherv)::self
        myrank = self%myrank
        return
    end function get_myrank

    integer function get_start(self) result(ista)
        implicit none
        class(gatherv)::self
        ista = self%displs(self%myrank+1) + 1
        return
    end function get_start

    integer function get_end(self) result(iend)
        implicit none
        class(gatherv)::self
        iend = self%displs(self%myrank+1+1)
        return
    end function get_end   

    type(gatherv) function init_gatherv(N) result(result)
        implicit none
        integer,intent(in)::N
        integer::ierr,nprocs,myrank,i
        integer::count,m

        call mpi_comm_rank(mpi_comm_world,myrank,ierr)
        call mpi_comm_size(mpi_comm_world,nprocs,ierr)
        result%nprocs = nprocs
        result%myrank = myrank
        result%N = N
        allocate(result%rcounts(1:nprocs))
        allocate(result%displs(1:nprocs+1))

        count = 0
        m = mod(N,nprocs) !N*nprocs + m
        count = (N-m)/nprocs
        if (mod(myrank-1+nprocs,nprocs)+1 <= m) then 
            count = count + 1
        end if
        result%count = count
        call mpi_allgather(count,1,mpi_integer,result%rcounts,1,mpi_integer,mpi_comm_world,ierr)

        result%displs(1) = 0
        do i=1,nprocs
            result%displs(i+1) = result%displs(i) + result%rcounts(i)
        end do

        return
    end function init_gatherv

    subroutine allgatherv_dble(self,v_ip,v) 
        implicit none
        class(gatherv)::self
        real(8),intent(in)::v_ip(:)
        real(8),intent(out)::v(:)
        integer::ierr
        call mpi_allgatherv(v_ip(1:self%count),self%count,mpi_double_precision,v, &
            self%rcounts,self%displs,mpi_double_precision,mpi_comm_world,ierr)
        return
    end subroutine allgatherv_dble

    subroutine allgatherv_complex(self,v_ip,v) 
        implicit none
        class(gatherv)::self
        complex(8),intent(in)::v_ip(:)
        complex(8),intent(out)::v(:)
        integer::ierr
        call mpi_allgatherv(v_ip(1:self%count),self%count,mpi_double_complex,v, &
            self%rcounts,self%displs,mpi_double_complex,mpi_comm_world,ierr)
    end subroutine allgatherv_complex  

    function pmap_dble(self,func) result(results)
        implicit none

        interface 
            function func(i)
                integer::i
                real(8)::func
            end function
        end interface 
        class(gatherv)::self
        integer::i
        integer::N
        integer::count
        real(8)::results(self%N),vec_tmp(self%count)
        N = self%N

        count = 0
        do i=self%get_start(),self%get_end()
            count = count + 1
            vec_tmp(count) = func(i)
        end do

        call self%allgatherv_dble(vec_tmp,results)
        return

    end function

end module mpicalls

となっています。これは、mpi並列化楽々化計画 in Fortran: mpi_allgathervを使いこなすで作ったモジュールに pmap_dbleという関数を追加しただけです。これによって、オブジェクトgがうまいことやってくれて、Doループの並列化ができます。なお、mpi_allgathervを使っていますから、ループの数が並列数で割り切れていなくてもうまいことやってくれます。

全体のコード

全体のコードは

main.f90
module mpicalls
    implicit none
    include "mpif.h"
    private
    public::gatherv!,init_gatherv

    type gatherv
        integer::nprocs
        integer::myrank
        integer::N
        integer,allocatable::rcounts(:)
        integer,allocatable::displs(:)
        integer::count

        contains
        procedure:: allgatherv_dble =>  allgatherv_dble
        procedure:: allgatherv_complex =>  allgatherv_complex
        procedure:: get_start
        procedure:: get_end
        procedure:: pmap_dble
        procedure::get_myrank
    end type gatherv

    interface gatherv
        module procedure::init_gatherv
    end interface gatherv

    contains

    integer function get_myrank(self) result(myrank)
        implicit none
        class(gatherv)::self
        myrank = self%myrank
        return
    end function get_myrank

    integer function get_start(self) result(ista)
        implicit none
        class(gatherv)::self
        ista = self%displs(self%myrank+1) + 1
        return
    end function get_start

    integer function get_end(self) result(iend)
        implicit none
        class(gatherv)::self
        iend = self%displs(self%myrank+1+1)
        return
    end function get_end   

    type(gatherv) function init_gatherv(N) result(result)
        implicit none
        integer,intent(in)::N
        integer::ierr,nprocs,myrank,i
        integer::count,m

        call mpi_comm_rank(mpi_comm_world,myrank,ierr)
        call mpi_comm_size(mpi_comm_world,nprocs,ierr)
        result%nprocs = nprocs
        result%myrank = myrank
        result%N = N
        allocate(result%rcounts(1:nprocs))
        allocate(result%displs(1:nprocs+1))

        count = 0
        m = mod(N,nprocs) !N*nprocs + m
        count = (N-m)/nprocs
        if (mod(myrank-1+nprocs,nprocs)+1 <= m) then 
            count = count + 1
        end if
        result%count = count
        call mpi_allgather(count,1,mpi_integer,result%rcounts,1,mpi_integer,mpi_comm_world,ierr)

        result%displs(1) = 0
        do i=1,nprocs
            result%displs(i+1) = result%displs(i) + result%rcounts(i)
        end do

        return
    end function init_gatherv

    subroutine allgatherv_dble(self,v_ip,v) 
        implicit none
        class(gatherv)::self
        real(8),intent(in)::v_ip(:)
        real(8),intent(out)::v(:)
        integer::ierr
        call mpi_allgatherv(v_ip(1:self%count),self%count,mpi_double_precision,v, &
            self%rcounts,self%displs,mpi_double_precision,mpi_comm_world,ierr)
        return
    end subroutine allgatherv_dble

    subroutine allgatherv_complex(self,v_ip,v) 
        implicit none
        class(gatherv)::self
        complex(8),intent(in)::v_ip(:)
        complex(8),intent(out)::v(:)
        integer::ierr
        call mpi_allgatherv(v_ip(1:self%count),self%count,mpi_double_complex,v, &
            self%rcounts,self%displs,mpi_double_complex,mpi_comm_world,ierr)
    end subroutine allgatherv_complex  

    function pmap_dble(self,func) result(results)
        implicit none

        interface 
            function func(i)
                integer::i
                real(8)::func
            end function
        end interface 
        class(gatherv)::self
        integer::i
        integer::N
        integer::count
        real(8)::results(self%N),vec_tmp(self%count)
        N = self%N

        count = 0
        do i=self%get_start(),self%get_end()
            count = count + 1
            vec_tmp(count) = func(i)
        end do

        call self%allgatherv_dble(vec_tmp,results)

        return

    end function

end module mpicalls


subroutine test2(i,a,b,v)
    implicit none
    integer,intent(in)::i
    real(8),intent(in)::a,b
    real(8),intent(out)::v
    v = i*a + b
    return
end subroutine

program main
    use mpicalls
    implicit none
    integer::N
    real(8),allocatable::vals(:)
    integer::ierr
    type(gatherv)::g
    real(8)::a,b
    external test2

    N = 12
    call mpi_init(ierr)
    allocate(vals(N))

    g = gatherv(N)

    vals= g%pmap_dble(test)

    if (g%get_myrank() == 0) then
        write(*,*) vals
    end if

    a = 100d0
    b = 1d0

    vals= g%pmap_dble(wrap)

    if (g%get_myrank() == 0) then
        write(*,*) vals
    end if


    contains 

    function wrap(i) result(v)
        implicit none
        real(8)::v
        integer::i
        call test2(i,a,b,v)
        return
    end function 

    function test(i) result(v)
        implicit none
        real(8)::v
        integer::i
        v = i + 1
        return
    end function 
end program main

です。FortranでMPIで困っている人の助けになれば幸いです。

2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0