41
41

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

MatlabやPythonのループは遅いって聞くけど本当?

Last updated at Posted at 2020-10-11

forループを回すと遅いらしい言語たち

コンパイル言語ではないにもかかわらず配列の扱いやすさによって数値計算で広く受け入れられているMatlab&Python(with numpy)はfor文で計算すると遅いのでそれをできる限り回避しよう!っていう言説をよく聞くと思います。
その影響はどれほどのものなのでしょうか。2次元の波動方程式の陽解法を使って検証してみました。

波動方程式の中心解法

支配方程式である波動方程式とその数値解法を見ていきましょう。

2次元の波動方程式は

\frac{\partial^2 u}{\partial t^2} - c^2\left(\frac{\partial^2 u}{\partial x^2}+\frac{\partial^2 u}{\partial y^2}\right)=0

であり、これの中心系の離散化は、

\frac{u^{n+1}_{i,j}-2u^n_{i,j}+u^{n-1}_{i,j}}{\Delta t^2} - c^2\left(\frac{u^n_{i+1,j}-2u^n_{i,j}+u^n_{i-1,j}}{\Delta x^2}+\frac{u^n_{i,j+1}-2u^n_{i,j}+u^n_{i,j-1}}{\Delta y^2}\right)=0

となります。
これを計算しやすいように整理すると、

u^{n+1}_{i,j}=2u^n_{i,j}-u^{n-1}_{i,j} +  c^2 \Delta t^2 \left(\frac{u^n_{i+1,j}-2u^n_{i,j}+u^n_{i-1,j}}{\Delta x^2}+\frac{u^n_{i,j+1}-2u^n_{i,j}+u^n_{i,j-1}}{\Delta y^2}\right)

となります。

境界条件はディリクレ条件であれば境界の値を代入するだけなので楽です。

Matlab

まずは適当にループを使って書いていきましょう。

c = 1.0;
xmin = 0.;
ymin = 0.;
xmax = 1.;
ymax = 1.; % 計算領域は[0,1],[0,1]
xmesh = 400; 
ymesh = 400; % 分割数はx,y軸ともに400
dx = (xmax-xmin)/xmesh;
dy = (ymax-ymin)/ymesh;
dt = 0.2*dx/c;
u0 = zeros(xmesh,ymesh); % u^{n-1}
u1 = zeros(xmesh,ymesh); % u^n
u2 = zeros(xmesh,ymesh); % u^{n+1}
u1(100:130,100:130)=1e-6;% 一定領域に初速を与えている。
x = xmin+dx/2:dx:xmax-dx/2;
y = ymin+dy/2:dy:ymax-dy/2;
t=0.;
tic %tic tocで時間を計測できる便利なやつ
while t<1.0
    for j = 2:ymesh-1
        for i = 2:xmesh-1
            u2(i,j) = 2*u1(i,j)-u0(i,j) + c*c*dt*dt*((u1(i+1,j)-2*u1(i,j)+u1(i-1,j))/(dx*dx) +(u1(i,j+1)-2*u1(i,j)+u1(i,j-1))/(dy*dy) );
        end
    end
    u0=u1;
    u1=u2;
    t = t+dt;
    %ディリクレ条件を与える
    for i=1:xmesh
        u1(i,1)=0.;
        u1(i,ymesh)=0.;
    end
    for j=1:ymesh
        u1(1,j)=0.;
        u1(xmesh,j)=0.;
    end
    
end
toc
[X,Y] = meshgrid(x,y);
mesh(X,Y,u1); %最後に描画

私の環境では経過時間は5.249230 秒でした。
これを時間部分以外ループなしで書いていきましょう。

c = 1.0;
xmin = 0.;
ymin = 0.;
xmax = 1.;
ymax = 1.; % 計算領域は[0,1],[0,1]
xmesh = 400; 
ymesh = 400; % 分割数はx,y軸ともに400
dx = (xmax-xmin)/xmesh;
dy = (ymax-ymin)/ymesh;
dt = 0.2*dx/c;
u0 = zeros(xmesh,ymesh); % u^{n-1}
u1 = zeros(xmesh,ymesh); % u^n
u2 = zeros(xmesh,ymesh); % u^{n+1}
u1(100:130,100:130)=1e-6;% 一定領域に初速を与えている。
x = xmin+dx/2:dx:xmax-dx/2;
y = ymin+dy/2:dy:ymax-dy/2;
t=0.;
tic %tic tocで時間を計測できる便利なやつ
while t<1.0
    u2(2:xmesh-1,2:ymesh-1) = 2*u1(2:xmesh-1,2:ymesh-1)-u0(2:xmesh-1,2:ymesh-1)+c*c*dt*dt*(diff(u1(:,2:ymesh-1),2,1)/(dx*dx)+diff(u1(2:xmesh-1,:),2,2)/(dy*dy));
    u0=u1;
    u1=u2;
    t = t+dt;
    %ディリクレ条件を与える
    u1(:,1)=0.;
    u1(:,ymesh)=0.;
    u1(1,:)=0.;
    u1(xmesh,:)=0.;
end
toc
[X,Y] = meshgrid(x,y);
mesh(X,Y,u1)

コードの見た目がスッキリしたように見えて書くときは少し頭を使います。diffを使った時のインデックスのズレなどを考える必要があるからです。
実行時間は2.734515秒でした。倍程度しか実行時間の差が開かず、予想以上に差が少ないっていう印象です。

Python

numpyが便利なのでよく使います。
ループ使って書いていきましょう。

import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import time
c = 1.0
xmin = 0.
ymin = 0.
xmax = 1.
ymax = 1.  # 計算領域は[0,1],[0,1]
xmesh = 400
ymesh = 400  # 分割数はx,y軸ともに400
dx = (xmax-xmin)/xmesh
dy = (ymax-ymin)/ymesh
dt = 0.2*dx/c
u0 = np.zeros((xmesh, ymesh))  # u^{n-1}
u1 = np.zeros((xmesh, ymesh))  # u^n
u2 = np.zeros((xmesh, ymesh))  # u^{n+1}
u1[100:130, 100:130] = 1e-6  # 一定領域に初速を与えている。
x = np.linspace(xmin+dx/2, xmax-dx/2, xmesh)
y = np.linspace(ymin+dy/2, ymax-dy/2, ymesh)
t = 0.
before = time.time()
while t < 1.:
    for j in range(1, xmesh-1):
        for i in range(1, ymesh-1):
            u2[i, j] = 2*u1[i, j]-u0[i, j]+c*c*dt*dt * \
                ((u1[i+1, j]-2.*u1[i, j]+u1[i-1, j])/(dx*dx) +
                 (u1[i, j+1]-2.*u1[i, j]+u1[i, j-1])/(dy*dy))

    u0 = deepcopy(u1)
    u1 = deepcopy(u2)
    t = t+dt
    # ディリクレ条件を与える
    for i in range(0,xmesh):
        u1[i, 0] = 0.
        u1[i, ymesh-1] = 0.
    for j in range(0,ymesh):
        u1[0, j] = 0.
        u1[xmesh-1, j] = 0.

print(time.time()-before)

X, Y = np.meshgrid(x, y)
fig = plt.figure()  # プロット領域の作成
ax = fig.gca(projection='3d')
ax.plot_surface(X, Y, u1)
plt.show()

計算が終わったのは880.909秒でした。
あまりにも遅い。

次はループなしのコード。

import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import time
c = 1.0
xmin = 0.
ymin = 0.
xmax = 1.
ymax = 1.  # 計算領域は[0,1],[0,1]
xmesh = 400
ymesh = 400  # 分割数はx,y軸ともに400
dx = (xmax-xmin)/xmesh
dy = (ymax-ymin)/ymesh
dt = 0.2*dx/c
u0 = np.zeros((xmesh, ymesh))  # u^{n-1}
u1 = np.zeros((xmesh, ymesh))  # u^n
u2 = np.zeros((xmesh, ymesh))  # u^{n+1}
u1[100:130, 100:130] = 1e-6  # 一定領域に初速を与えている。
x = np.linspace(xmin+dx/2, xmax-dx/2, xmesh)
y = np.linspace(ymin+dy/2, ymax-dy/2, ymesh)
t = 0.
before = time.time()
while t < 1.0:
    u2[1:xmesh-1, 1:ymesh-1] = 2*u1[1:xmesh-1, 1:ymesh-1]-u0[1:xmesh-1, 1:ymesh-1]+c*c*dt*dt * \
        (np.diff(u1[:, 1:ymesh-1], n=2, axis=0)/(dx*dx) +
         np.diff(u1[1: xmesh-1, :], n=2, axis=1)/(dy*dy))
    u0 = deepcopy(u1)
    u1 = deepcopy(u2)
    t = t+dt
    # ディリクレ条件を与える
    u1[:, 0] = 0.
    u1[:, ymesh-1] = 0.
    u1[0, :] = 0.
    u1[xmesh-1, :] = 0.

print(time.time()-before)

X, Y = np.meshgrid(x, y)
fig = plt.figure()  # プロット領域の作成
ax = fig.gca(projection='3d')
ax.plot_surface(X, Y, u1)
plt.show()

結果は5.539646625518799秒。pythonループ版に比べ100倍以上速くなっていますがMatlabに比べると結構時間かかりますね
書き方が悪いんでしょうか。

結論

Pythonで2重ループを振り回すのは極力控えましょう。異常に計算時間がかかります。Matlabなら少し遅くなるくらいで済みました。

2020/10/14 追記 Matlabに詳しい人が教えてくれました.Matlabは自動的にJITコンパイルされているみたいです.素晴らしいです.
https://jp.mathworks.com/products/matlab/matlab-execution-engine.html

おまけ

波の反射ってきれいですね

c = 1.0;
xmin = 0.;
ymin = 0.;
xmax = 1.;
ymax = 1.; % 計算領域は[0,1],[0,1]
xmesh = 100; 
ymesh = 100; % 分割数はx,y軸ともに400
dx = (xmax-xmin)/xmesh;
dy = (ymax-ymin)/ymesh;
dt = 0.1*dx/c;
u0 = zeros(xmesh,ymesh); % u^{n-1}
u1 = zeros(xmesh,ymesh); % u^n
u2 = zeros(xmesh,ymesh); % u^{n+1}

x = xmin+dx/2:dx:xmax-dx/2;
y = ymin+dy/2:dy:ymax-dy/2;
t=0.;
[X,Y] = meshgrid(x,y);
[smallX,smallY] = meshgrid(linspace(0,pi,21),linspace(0,pi,21));
u1(20:40,30:50)=2e-7*sin(smallX).*sin(smallY);% 一定領域に初速を与えている。
u1(60:80,40:60)=1e-7*sin(smallX).*sin(smallY);
mesh(X,Y,u1);
zlim([-1e-5 1e-5]);
tic %tic tocで時間を計測できる便利なやつ
while t<10.0
    u2(2:xmesh-1,2:ymesh-1) = 2*u1(2:xmesh-1,2:ymesh-1)-u0(2:xmesh-1,2:ymesh-1)+c*c*dt*dt*(diff(u1(:,2:ymesh-1),2,1)/(dx*dx)+diff(u1(2:xmesh-1,:),2,2)/(dy*dy));
    u0=u1;
    u1=u2;
    t = t+dt;
    u1(:,1)=0.;
    u1(:,ymesh)=0.;
    u1(1,:)=0.;
    u1(xmesh,:)=0.;
    
    mesh(X,Y,u1);
    zlim([-1e-5 1e-5]);
    pause(0.01);
end
toc
[X,Y] = meshgrid(x,y);
mesh(X,Y,u1)

実行すると波の反射の様子がきれいにアニメーションみたいに見れます。

41
41
4

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
41
41

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?