4
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

pythonでのMPI並列化

Last updated at Posted at 2023-01-28

pythonでMPI通信を行うときはmpi4pyというパッケージを利用します。

mpi4pyのインストール

普通のやり方

pip install mpi4py

conda install -c conda-forge mpi4py

でインストールできます。

オフラインでのやり方

並列化計算をするようなクラスタコンピューターはセキュリティのためにオフラインになっていることが多いです。そのような場合は手動でインストールします。
まずは

からmpi4py-(バージョン).tar.gzをダウンロードして並列計算を行うコンピューターに送ります。そこで

tar zxvf mpi4py-(バージョン).tar.gz

とファイルを解凍してできたフォルダmpi4py-(バージョン)に入ります。そのフォルダに入っているsetup.py

python setup.py install

と実行すればmpi4pyモジュールが使えるようになります。

実行

使用するマシンの仕様にも依りますが

mpirun -np (コア数) python (実行コード)

などで実行できます。

基本的な使い方

クラスMPI.COMM_WORLDに必要な変数や関数が大体入っています。

例えば使用するコア数や自分のコアの番号は

mpi.py
size = MPI.COMM_WORLD.Get_size()
rank = MPI.COMM_WORLD.Get_rank()

で取得できます。

テストコード

helloMPI.py
from mpi4py import MPI

comm = MPI.COMM_WORLD 

size = comm.Get_size()
rank = comm.Get_rank()

print("Hello world {0} / {1}".format(rank, size))

10並列での実行結果:

Hello world 1 / 10
Hello world 6 / 10
Hello world 0 / 10
Hello world 5 / 10
Hello world 4 / 10
Hello world 3 / 10
Hello world 8 / 10
Hello world 2 / 10
Hello world 7 / 10
Hello world 9 / 10

Allreduce

allreduce.py
from mpi4py import MPI
import numpy as np


comm = MPI.COMM_WORLD
size = comm.Get_size()
rank = comm.Get_rank()

sendbuf = rank*np.ones(10, dtype="float64")
recvbuf = np.empty(sendbuf.shape, dtype="float64")

comm.Allreduce(sendbuf, recvbuf, MPI.SUM)


if(rank == 0):
    print("procs:{0}".format(size))
    print(recvbuf)

4並列での実行結果

procs:4
[6. 6. 6. 6. 6. 6. 6. 6. 6. 6.]

配列の分割

配列の和をとるときに次のクラスassignを使って各コアが計算する部分配列を定めて分割して和を計算します。例としてすべて要素が1の$2\times 2$行列の配列を作り、行列の和を計算しました。

Allreduce.py
from mpi4py import MPI
import numpy as np
import itertools

class assign:

    def __init__(self,procs,size):
        self.size = size
        self.procs = procs
        self.num = [ size//procs + (1 if i<size%procs else 0) for i in range(procs) ]
        self.id = [0] + list(itertools.accumulate( self.num))

size = 10
arr = np.ones((size,2,2))

comm = MPI.COMM_WORLD
procs = comm.Get_size()
rank = comm.Get_rank()

ASGN = assign(procs,size)

sendbuf = np.sum(arr[ASGN.id[rank]:ASGN.id[rank+1]],axis=0)
recvbuf =  np.empty(sendbuf.shape)

comm.Allreduce(sendbuf, recvbuf, MPI.SUM)

if(rank == 0):
    print("procs:{0}".format(procs))
    print(recvbuf)
    print(np.sum(arr,axis=0))

出力:

procs:4
[[10. 10.]
 [10. 10.]]
[[10. 10.]
 [10. 10.]]

参考資料

4
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?