はじめに
Fortranを使う人に簡単にMPI並列化を使用するための方法をお伝えします。
私自身はFortranを長いこと使ってきまして、特に不便も感じずに過ごしてきていましたが、他の言語(Julia)を知ると「Fortranだとめんどい!」という場面が生じるようになってきました。というわけで、FortranでDoループを簡単に並列化する方法について述べます。特に、オブジェクト指向Fortranを使うことで非常に簡単にMPI並列ができることをみます。
なお、mpi並列化楽々化計画 in Fortran: mpi_allgathervを使いこなす
の続編です。
コード
まず、これをみてください。
program main
use mpicalls
implicit none
integer::N
real(8),allocatable::vals(:)
integer::ierr
type(gatherv)::g
N = 12
call mpi_init(ierr)
allocate(vals(N))
g = gatherv(N)
vals= g%pmap_dble(test)
if (g%get_myrank() == 0) then
write(*,*) vals
end if
contains
function test(i) result(v)
implicit none
real(8)::v
integer::i
v = i + 1
return
end function
end program main
このコード、vals= g%pmap_dble(test)
だけで長さNのループを並列化して実行してその結果をvalsに格納してくれます。Juliaでのpmapみたいなものを作ってみました。このコードの関数test(i)
はループ番号iを引数にした関数です。これだけではこのような引数が1個の場合にしか使えない関数に見えるかもしれませんが、そんなことはなく、
subroutine test2(i,a,b,v)
implicit none
integer,intent(in)::i
real(8),intent(in)::a,b
real(8),intent(out)::v
v = i*a + b
return
end subroutine
program main
use mpicalls
implicit none
integer::N
real(8),allocatable::vals(:)
integer::ierr
type(gatherv)::g
real(8)::a,b
external test2
N = 12
call mpi_init(ierr)
allocate(vals(N))
g = gatherv(N)
a = 100d0
b = 1d0
vals= g%pmap_dble(wrap)
if (g%get_myrank() == 0) then
write(*,*) vals
end if
contains
function wrap(i) result(v)
implicit none
real(8)::v
integer::i
!write(*,*) i
call test2(i,a,b,v)
return
end function
function test(i) result(v)
implicit none
real(8)::v
integer::i
v = i + 1
return
end function
end program main
とwrapという関数を使えば、任意の変数の数のfunctionやsubroutineを突っ込むことができます。
つまり、このコードを使えば任意のfunctionやsubroutineをDoループ並列化できるわけです。
モジュール
ポイントはmpicallsというモジュールです。この中身は、
module mpicalls
implicit none
include "mpif.h"
private
public::gatherv!,init_gatherv
type gatherv
integer::nprocs
integer::myrank
integer::N
integer,allocatable::rcounts(:)
integer,allocatable::displs(:)
integer::count
contains
procedure:: allgatherv_dble => allgatherv_dble
procedure:: allgatherv_complex => allgatherv_complex
procedure:: get_start
procedure:: get_end
procedure:: pmap_dble
procedure::get_myrank
end type gatherv
interface gatherv
module procedure::init_gatherv
end interface gatherv
contains
integer function get_myrank(self) result(myrank)
implicit none
class(gatherv)::self
myrank = self%myrank
return
end function get_myrank
integer function get_start(self) result(ista)
implicit none
class(gatherv)::self
ista = self%displs(self%myrank+1) + 1
return
end function get_start
integer function get_end(self) result(iend)
implicit none
class(gatherv)::self
iend = self%displs(self%myrank+1+1)
return
end function get_end
type(gatherv) function init_gatherv(N) result(result)
implicit none
integer,intent(in)::N
integer::ierr,nprocs,myrank,i
integer::count,m
call mpi_comm_rank(mpi_comm_world,myrank,ierr)
call mpi_comm_size(mpi_comm_world,nprocs,ierr)
result%nprocs = nprocs
result%myrank = myrank
result%N = N
allocate(result%rcounts(1:nprocs))
allocate(result%displs(1:nprocs+1))
count = 0
m = mod(N,nprocs) !N*nprocs + m
count = (N-m)/nprocs
if (mod(myrank-1+nprocs,nprocs)+1 <= m) then
count = count + 1
end if
result%count = count
call mpi_allgather(count,1,mpi_integer,result%rcounts,1,mpi_integer,mpi_comm_world,ierr)
result%displs(1) = 0
do i=1,nprocs
result%displs(i+1) = result%displs(i) + result%rcounts(i)
end do
return
end function init_gatherv
subroutine allgatherv_dble(self,v_ip,v)
implicit none
class(gatherv)::self
real(8),intent(in)::v_ip(:)
real(8),intent(out)::v(:)
integer::ierr
call mpi_allgatherv(v_ip(1:self%count),self%count,mpi_double_precision,v, &
self%rcounts,self%displs,mpi_double_precision,mpi_comm_world,ierr)
return
end subroutine allgatherv_dble
subroutine allgatherv_complex(self,v_ip,v)
implicit none
class(gatherv)::self
complex(8),intent(in)::v_ip(:)
complex(8),intent(out)::v(:)
integer::ierr
call mpi_allgatherv(v_ip(1:self%count),self%count,mpi_double_complex,v, &
self%rcounts,self%displs,mpi_double_complex,mpi_comm_world,ierr)
end subroutine allgatherv_complex
function pmap_dble(self,func) result(results)
implicit none
interface
function func(i)
integer::i
real(8)::func
end function
end interface
class(gatherv)::self
integer::i
integer::N
integer::count
real(8)::results(self%N),vec_tmp(self%count)
N = self%N
count = 0
do i=self%get_start(),self%get_end()
count = count + 1
vec_tmp(count) = func(i)
end do
call self%allgatherv_dble(vec_tmp,results)
return
end function
end module mpicalls
となっています。これは、mpi並列化楽々化計画 in Fortran: mpi_allgathervを使いこなすで作ったモジュールに pmap_dble
という関数を追加しただけです。これによって、オブジェクトgがうまいことやってくれて、Doループの並列化ができます。なお、mpi_allgatherv
を使っていますから、ループの数が並列数で割り切れていなくてもうまいことやってくれます。
全体のコード
全体のコードは
module mpicalls
implicit none
include "mpif.h"
private
public::gatherv!,init_gatherv
type gatherv
integer::nprocs
integer::myrank
integer::N
integer,allocatable::rcounts(:)
integer,allocatable::displs(:)
integer::count
contains
procedure:: allgatherv_dble => allgatherv_dble
procedure:: allgatherv_complex => allgatherv_complex
procedure:: get_start
procedure:: get_end
procedure:: pmap_dble
procedure::get_myrank
end type gatherv
interface gatherv
module procedure::init_gatherv
end interface gatherv
contains
integer function get_myrank(self) result(myrank)
implicit none
class(gatherv)::self
myrank = self%myrank
return
end function get_myrank
integer function get_start(self) result(ista)
implicit none
class(gatherv)::self
ista = self%displs(self%myrank+1) + 1
return
end function get_start
integer function get_end(self) result(iend)
implicit none
class(gatherv)::self
iend = self%displs(self%myrank+1+1)
return
end function get_end
type(gatherv) function init_gatherv(N) result(result)
implicit none
integer,intent(in)::N
integer::ierr,nprocs,myrank,i
integer::count,m
call mpi_comm_rank(mpi_comm_world,myrank,ierr)
call mpi_comm_size(mpi_comm_world,nprocs,ierr)
result%nprocs = nprocs
result%myrank = myrank
result%N = N
allocate(result%rcounts(1:nprocs))
allocate(result%displs(1:nprocs+1))
count = 0
m = mod(N,nprocs) !N*nprocs + m
count = (N-m)/nprocs
if (mod(myrank-1+nprocs,nprocs)+1 <= m) then
count = count + 1
end if
result%count = count
call mpi_allgather(count,1,mpi_integer,result%rcounts,1,mpi_integer,mpi_comm_world,ierr)
result%displs(1) = 0
do i=1,nprocs
result%displs(i+1) = result%displs(i) + result%rcounts(i)
end do
return
end function init_gatherv
subroutine allgatherv_dble(self,v_ip,v)
implicit none
class(gatherv)::self
real(8),intent(in)::v_ip(:)
real(8),intent(out)::v(:)
integer::ierr
call mpi_allgatherv(v_ip(1:self%count),self%count,mpi_double_precision,v, &
self%rcounts,self%displs,mpi_double_precision,mpi_comm_world,ierr)
return
end subroutine allgatherv_dble
subroutine allgatherv_complex(self,v_ip,v)
implicit none
class(gatherv)::self
complex(8),intent(in)::v_ip(:)
complex(8),intent(out)::v(:)
integer::ierr
call mpi_allgatherv(v_ip(1:self%count),self%count,mpi_double_complex,v, &
self%rcounts,self%displs,mpi_double_complex,mpi_comm_world,ierr)
end subroutine allgatherv_complex
function pmap_dble(self,func) result(results)
implicit none
interface
function func(i)
integer::i
real(8)::func
end function
end interface
class(gatherv)::self
integer::i
integer::N
integer::count
real(8)::results(self%N),vec_tmp(self%count)
N = self%N
count = 0
do i=self%get_start(),self%get_end()
count = count + 1
vec_tmp(count) = func(i)
end do
call self%allgatherv_dble(vec_tmp,results)
return
end function
end module mpicalls
subroutine test2(i,a,b,v)
implicit none
integer,intent(in)::i
real(8),intent(in)::a,b
real(8),intent(out)::v
v = i*a + b
return
end subroutine
program main
use mpicalls
implicit none
integer::N
real(8),allocatable::vals(:)
integer::ierr
type(gatherv)::g
real(8)::a,b
external test2
N = 12
call mpi_init(ierr)
allocate(vals(N))
g = gatherv(N)
vals= g%pmap_dble(test)
if (g%get_myrank() == 0) then
write(*,*) vals
end if
a = 100d0
b = 1d0
vals= g%pmap_dble(wrap)
if (g%get_myrank() == 0) then
write(*,*) vals
end if
contains
function wrap(i) result(v)
implicit none
real(8)::v
integer::i
call test2(i,a,b,v)
return
end function
function test(i) result(v)
implicit none
real(8)::v
integer::i
v = i + 1
return
end function
end program main
です。FortranでMPIで困っている人の助けになれば幸いです。