16
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

Rustその2Advent Calendar 2019

Day 12

自動微分 in Rust!!!!

Last updated at Posted at 2019-12-12

この記事は Rustその2 Advent Calendar 2019 12/12 の記事です.

広島大学工学部4年生荒木勇登です!
来年から福岡でITエンジニアします!
いみこというハンドルネームで Twitter をやっていますのでそちらもご確認ください。

#自動微分とはなにか
wikiより

自動微分(じどうびぶん、アルゴリズム的微分とも)とは、プログラムで定義された関数を解析し、偏導関数の値を計算するプログラムを導出する技術である。

ざっくりいうと数値微分ではない方法で導関数の値を求める方法だと思っていただければ良いかと思います。
自動微分にはフォワードモードとリバースモードという種類があります。
今回は、フォワードモードを演算子オーバーロードで実装することを目標とします。(参考:自動微分を実装して理解する(前編)

#数学的な準備と実装の解説
上のことを実装するために二重数という概念を導入します。
二重数とは以下のような概念です。

二重数(にじゅうすう、英: dual numbers)は、実数の全体に実数ではない新しい元 ε で複零性 ε2 = 0 を満たすものを添加して得られる実数の拡張概念である。
wikiより

要は、複素数は二乗すると−1になる元iを導入しましたが、 二重数は二乗すると0になるεを導入します、ということです。

これをRustの構造体で表現すると以下のようになります。(実装はこちらのブログを参考にさせていただいております。)

DualNumberStruct.rs
pub struct Dual {
    var: f64,
    dual: f64,
}

さて、次に二重数の四則演算を確認し、その結果を演算子オーバーロードで実装しましょう。
(参考:演算子のオーバーロード

##足し算と引き算

これは簡単です。

(a + bε) + (c + dε) = (a + c) + (b + d)ε\\

(a + bε) - (c + dε) = (a - c) + (b - d)ε


上記の結果より以下のように実装します

AddAndSunOverload.rs
impl Add for Dual {
    type Output = Dual;
    //例えば、X + Y でXがself、Yがr    
    fn add(self, r: Dual) -> Dual {
        //   (self_var + self_dual*ε) + (r_var + r_dual*ε)
        // = (self_var + r_var)      + (self_dual + r_dual)*ε
        Dual {
            var: self.var + r.var,
            dual: self.dual + r.dual
        }
    }
}

impl Sub for Dual {
    type Output = Dual;
    fn sub(self, r: Dual) -> Dual {
        //   (self_var + self_dual*ε) - (r_var + r_dual*ε)
        // = (self_var - r_var)      + (self_dual - r_dual)*ε
        Dual {
            var: self.var - r.var,
            dual: self.dual - r.dual
        }
    }
}

##掛け算

(a + bε) * (c + dε) = (a + c) + (bc + ad)ε + bdε ^2

ε ^2が初めて出てきました。これは二乗数の定義より0となります。
よって計算結果は以下のようになります。

(a + bε) * (c + dε) = (a + c) + (bc + ad)ε 

これを演算子オーバーロードで実装すると以下のようになります。

MulOverload.rs
impl Mul for Dual {
    type Output = Dual;
    fn mul(self,r:Dual) -> Dual {
        //   (self_var + self_dual*ε) * (r_var + r_dual*ε)
        // = (self_var + r_var) + (self_dual*r_var + self_var*r_dual)ε
        Dual {
            var: self.var * r.var,
            dual: self.dual * r.var + self.var * r.dual
        }
    }
}

##割り算

(a + bε) / (c + dε) = (a / c) + ((bc - ad)/ c^2 )ε

この式はc = 0 のとき成り立ちませんが、今回は発生しないとして無視します。

これを演算子オーバーロードで実装すると以下のようになります。

MulOverload.rs
impl Div for Dual {
    type Output = Dual;
    fn div(self, r: Dual) -> Dual {
        //   (self_var + self_dual*ε) / (r_var + r_dual*ε)
        // = (self_var / r_var) + (r_dual*self_var/r_var^2)ε
        Dual {
            var: self.var / r.var,
            eps: self.dual/r.var - r.dual*self.var/r.var/r.var
        }
    }
}

#計算例
さて、これらを使って微分を計算する方法を説明します。

f(x) = x^3 + 2x

の x = 2 の導関数の値を求めます。答えは

f'(x) = 3x^2 + 2\\
f'(2) = 14

です。
どのように二重数を使って導関数の値を計算するのかというと関数f(x)に2 + εを代入し、その計算結果のεの係数が求めたい導関数の値となります。実際に計算してみましょう。

f(2 + ε) = (2 + ε)^3 + 2 * (2 + ε)\\
          = 2^3 + 3*2^2*ε + 3*2*ε^2 + ε^3 + 2 + 2*ε \\
= 8 + 12*ε + 6 + 2*ε\\
= 12 + 14ε

Rustで書くとこんな感じ。

example.rs
use std::ops::{Add,Mul};
  
#[derive(Debug,Copy,Clone)]
pub struct Dual {
    var: f64,
    dual: f64,
}

impl Add for Dual {
    type Output = Dual;
    fn add(self, r: Dual) -> Dual {
        Dual {
            var: self.var + r.var,
            dual: self.dual + r.dual
        }
    }
}
impl Mul for Dual {
    type Output = Dual;
    fn mul(self,r:Dual) -> Dual {
        Dual {
            var: self.var * r.var,
            dual: self.dual * r.var + self.var * r.dual
        }
    }
}
fn func(x: Dual) -> Dual {
    let a : Dual = Dual{var: 2f64,dual: 0f64};
    x * x * x + a * x
}
fn main(){
    let x:  Dual = Dual{var: 2f64,dual: 1f64};
    println!("{:?}",func(x));//Dual { var: 12.0, dual: 14.0 }
}

ということでεの係数が求めたい導関数の値14となりました!!!!すごい!!!!!不思議ですね!!!!!!どうして求められるのでしょうか!!!!!

これは四則演算の、特に割り算や掛け算などを注意してみていただきたいのですが、εの係数が微分の公式と同じ形をしています。

(a + bε) * (c + dε) = (a + c) + (bc + ad)ε \\
(f *g)' = f'*g + f*g'\\
(a + bε) / (c + dε) = (a / c) + ((bc - ad)/ c^2 )ε\\
(f/g )'= (f'g - fg'/g^2)

a = f,b = f',c = g,d = g'を代入すること一致することが確認できるかと思います。(足し算と引き算の場合も一致します。)この性質があることにより、εの係数が導関数の値となります。また同様に、四則演算以外の三角関数などでも二重数に拡張するとき、εの係数に元の関数の微分が来るように定義することで同様に導関数の値を求めることができます。例えば、以下のように

sinOverload
impl Dual {
    fn sin(self) -> Dual {
        //f (x)     = sinx
        //f*(x + £) = f(x) + f'(x)*£
        //f*(x + £) = sinx +  cosx*£
        Dual {
            var: self.var.sin(),
            eps: self.dual*self.var.cos()
        }
    }
}
fn main(){
    let x:  Dual = Dual{var: 0f64,dual: 1f64};
    let pi: Dual = Dual{var: 3.14,dual: 1f64};
    println!("{}",x.sin().dual);//1
    println!("{}",pi.sin().dual);//-0.9999987317275395
}

#参考文献
wikipedia 自動微分
RustでForward自動微分を実装してみた
なぜ双対数(二重数)で微分を求められるのか・・・?
自動微分を実装して理解する(前編)
演算子のオーバーロード

16
10
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
16
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?