LoginSignup
2
1

More than 5 years have passed since last update.

シンプルな計算グラフの保持クラス

Posted at

CalculationRecorder

ソースはこちらです。

以下のようにしてクラスを定義します。

import calculation_recorder as cr

func = cr.CalculationRecorder()

以下のように適当な値をどんどん足してみます。

func = func + 1
func = func + 2
func = func + 3
func = func + 4
func = func + 5
func = func + 6
func = func + 7
func = func + 8
func = func + 9
func = func + 10

funcにはこれまでの計算過程が含まれています。
適当な値を順伝播させてみましょう。

func.forward(0)
55

となります。単純な$0+1+2+3+4+5+6+7+8+9+10=55$です。

次に、逆伝播させてみます。55を逆伝播させると0になりそうです。

func.backward(55)
0

0でした。

用途

画像処理の各種座標変換に使っています。
DLの前処理のクラスに忍ばせておくと、座標変換後の$(x,y)$は座標変換前の座標系でいうとどのあたりなのかなどを悩まなくてよくなります。

例えば、画像を半分に圧縮した上で{$(x,y)|100\leq x<200,100 \leq y<200$}として、更に4倍に引き伸ばしたときの画像の$(x,y)=(20,20)$はもとの画像でいうとどこか。ということを考えます。一つ一つ考えてコードを書いていけばよいですが、このクラスを使うと以下のように簡単にかけます。

from numpy import array
func = cr.CalculationRecorder()
func = func / 2
func = func - 100
func = func * 4

func.backward(20)

x,yは独立なので、numpy.arrayで処理できれば楽ですが、対応していません。

2
1
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
1