LoginSignup
13
4

More than 3 years have passed since last update.

TensorFlow 2.0 で微分方程式の数値計算をした話

Last updated at Posted at 2019-12-25

はじめに

TensorFlow 2.0 はディープラーニングを初めとした機械学習に利用できる非常に強力なフレームワークです。TensorFlow では数値計算のアルゴリズムを「計算グラフ」という形で抽象的に取り扱います。

SQL でいう実行計画っぽいですね。これによって「アルゴリズムの定義」と「その実行」を分離することができそうです。うまく使えば、機械学習だけでなく一般の数値計算に利用する上でも非常に役立つのでは?

と思ってやってみた

今回は Lotka-Volterra 方程式を解いてみましょう! 捕食者と被食者の個体数の変動をモデル化した方程式です。自然界でも見られる「周期解」が再現されることでも有名ですね。

import tensorflow as tf
import matplotlib.pyplot as plt
import timeit

a = 0.2
b = 0.04
c = 0.4
d = 0.05

d_esa = lambda esa, hoshokusha: a*esa - b*esa*hoshokusha
d_hoshokusha = lambda esa, hoshokusha: -c*hoshokusha + d*esa*hoshokusha

def runge_kutta(f, x, y, dt):
    k1 = dt*f(x,y)
    k2 = dt*f(x+0.5*k1, y+0.5*k1)
    k3 = dt*f(x+0.5*k2, y+0.5*k2)
    k4 = dt*f(x+k3, y+k3)
    return (k1+2.0*k2+2.0*k3+k4)/6.0

@tf.function
def lotka_volterra(t_max, t_num, esa_init, hoshokusha_init):
    esa_result = []
    hoshokusha_result = []
    t_result = []

    t = 0.0
    esa = esa_init
    hoshokusha = hoshokusha_init

    esa_result.append(esa)
    hoshokusha_result.append(hoshokusha)
    t_result.append(t)

    dt = t_max / t_num

    while t < t_max:
        t += dt
        esa += runge_kutta(d_esa, esa, hoshokusha, dt)
        hoshokusha += runge_kutta(d_hoshokusha, esa, hoshokusha, dt)

        esa_result.append(esa)
        hoshokusha_result.append(hoshokusha)
        t_result.append(t)

    return esa_result, hoshokusha_result, t_result

# warm up!!!!!!
esa_result, hoshokusha_result, t_result = lotka_volterra(100.0, 2000, 10, 1)
print(timeit.timeit(lambda: lotka_volterra(100.0, 2000, 10, 1), number=1))

plt.plot(t_result, esa_result)
plt.plot(t_result, hoshokusha_result)
plt.show()

tf.function デコレータをつけることで、内部の計算が TensorFlow 上のグラフとして計算されることになります。Tensorflow 2.0 から「Python コードと織り交ぜて書く」こともある程度可能になりました。ありがたいですね。今回の場合は tf.function を取ることで「ピュアな Python コード」としてすぐに実行することが可能です。

なお、私個人の環境で tf.function を付けたほうが4倍程度実行時間が遅くなりました。(CPU)
もちろん今回はオーバーヘッドにしかなってないから、当然ですね(´・ω・`)

おわりに

いかがでしたか?

Lotka-Volterra 方程式程度なら TensorFlow 2.0 を全く使う必要がないことがわかりました!

偏微分方程式など、行列(テンソル)の演算を多用するような、今回よりかなり重たい数値計算であれば TensorFlow が活躍できるかもしれません。

やろうとしたけれど、うまくいかなかったのはサンタさんとの秘密だよ!
良いお年を :dogeza:

13
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
13
4