はじめに
この記事は、TypeScript Advent Calendar 2018 の21日目です。
前日の記事は @__sakito__ さん
明日の記事は @mrsekut さんです。
本記事のテーマは自動微分です。
自動微分とは
さて、自動微分とはなにかからです。Wikipedia曰く、
自動微分(じどうびぶん、アルゴリズム的微分とも)とは、プログラムで定義された関数を解析し、偏導関数の値を計算するプログラムを導出する技術である。自動微分は複雑なプログラムであっても加減乗除などの基本的な算術演算や基本的な関数(指数関数・対数関数・三角関数など)のような基本的な演算の組み合わせで構成されていることを利用し、これらの演算に対して連鎖律を繰り返し適用することによって実現される。自動微分を用いることで偏導関数値を少ない計算量で自動的に求めることができる。
とのことです。
つまり、簡単な四則演算とか初等関数に対して微分の実装をしておき、連鎖律を適用することで微分の計算をしてしまおうというものですね。
二重数
自動微分を実際に行うに当たり、色々な実装が考えられますが、その中で最も簡単なものの一つが二重数です。
この二重数というのは、通常の数値に対して一次元の微小量を付加することで一次微分の値を一緒に取り出すという代物です。
例えば、
f(x) = 3x^2 + 5x - 2
みたいな二次関数があったとしましょう。この関数の導関数は、
f^{\prime}(x) = 6x + 5
ですので、 $x = -2$ における値とその微分値は、
f(-2) = 0\\
f^{\prime}(-2) = -7
となりますね。この微分値を導関数を求めることなく、導出するのに二重数を用いましょう。
代入する値を $x = -2 + \epsilon$ としてみます。そうすると、
f(-2 + \epsilon) = 3\cdot(-2 + \epsilon)^2 + 5\cdot(-2 + \epsilon) - 2 = 0 -7\epsilon + 3\epsilon^{2}
となり、結果を見てみると定数項が値、$\epsilon$ の係数が微分値になっていますね!
このように二次の微小項を$0$として、それぞれの係数を見ることで微分が計算できます。
微分の整理
さて、まずは二重数の二項演算について整理してみましょう。
(a_{0} + b_{0}\epsilon) + (a_{1} + b_{1}\epsilon) = (a_{0} + a_{1}) + (b_{0} + b_{1})\epsilon\\
(a_{0} + b_{0}\epsilon) - (a_{1} + b_{1}\epsilon) = (a_{0} - a_{1}) + (b_{0} - b_{1})\epsilon\\
(a_{0} + b_{0}\epsilon)(a_{1} + b_{1}\epsilon) = a_{0}a_{1} + (a_{0}b_{1} + a_{1}b_{0})\epsilon\\
\dfrac{a_{0} + b_{0}\epsilon}{a_{1} + b_{1}\epsilon} = \dfrac{a_{0}}{a_{1}} + \dfrac{a_{1}b_{0} - a_{0}b_{1}}{a_{1}^{2}}\epsilon\\
次に単項演算です。ある単項演算 $f$ について一次近似を考えると、
f(a + b\epsilon) = f(a) + \left(\dfrac{d}{dx}f(a)\right)b\epsilon
となります。
これらの演算について定義されているクラスを作れば、自動微分されるということです。
やっと実装
長かったですね、やっと実装ですが、その前に開発環境を整えましょう。
今回はjestを導入して、そこに期待する結果を先に書いてから、実装をしてみましょう。
開発環境
なにはともあれ、package.json
です。
{
"private": true,
"scripts": {
"test": "jest"
}
}
次に必要なパッケージのインストールです。
npm i --save-dev typescript jest @types/jest ts-jest
jestとtypescriptの設定ファイルは、
module.exports = {
moduleFileExtensions: ["ts", "js"],
transform: {
"^.+\\.(ts)$": "ts-jest"
},
globals: {
"ts-jest": {
tsConfig: "tsconfig.test.json"
}
},
testMatch: ["**/__tests__/*.+(ts|js)"]
};
{
"compilerOptions": {
"sourceMap": true,
"noImplicitAny": true,
"module": "commonjs",
"target": "es5",
"lib": ["es2018", "dom"],
"moduleResolution": "node",
"removeComments": true,
"strict": true,
"noUnusedLocals": true,
"noUnusedParameters": true,
"noImplicitReturns": true,
"noFallthroughCasesInSwitch": true,
"typeRoots": ["node_modules/@types"]
},
"exclude": ["node_modules"]
}
です。特段気にすべきこともありません。
実装
さて、今度こそ本当に実装です。
テスト
まずはテストです。
import * as UOp from '../Unary';
import * as BOp from '../Binary';
import * as Dual from '../Dual';
describe('integration test', () => {
it('unary operator', () => {
const x = Dual.variable(0.5);
const y = UOp.negate(
UOp.sqrt(UOp.log(UOp.exp(UOp.tan(UOp.sin(UOp.cos(x))))))
);
expect(y.x()).toEqual(-0.9839259443171812);
expect(y.dx()).toEqual(0.3015928005470888);
});
it('linear function', () => {
const x = Dual.variable(4);
const y = BOp.sub(BOp.mul(3, x), 2);
expect(y.x()).toEqual(10);
expect(y.dx()).toEqual(3);
});
it('quadrqtic function', () => {
const x = Dual.variable(-2);
const y = BOp.sub(BOp.add(BOp.mul(3, BOp.mul(x, x)), BOp.mul(5, x)), 2);
expect(y.x()).toEqual(0);
expect(y.dx()).toEqual(-7);
});
});
単項演算と、二項演算はそれぞれをファイルUnary
とBinary
に分けている前提です。
テストとして以下の3つを用意しています。
- 単項演算をガチャガチャっと適用した関数の値と微分値
- 簡単な二項演算のテストとして一次関数の値と微分値
- 二重数同士の演算のテストとして二次関数の値と微分値
これらのテストが通るように実装を進めていきましょう。
二重数
では、二重数の実装です。
export const variable = (x: number): Variable => {
return new Variable(x, 1);
};
export class Variable {
constructor(private _x: number, private _dx: number) {}
public x(): number {
return this._x;
}
public dx(): number {
return this._dx;
}
}
事前準備でも見た通り、実際の値と微分係数を保持する形でクラスを作ります。
コンストラクタの引数にアクセス修飾子をつけていますが、これは、
class Variable {
constructor(x: number, dx: number) {
this._x = x;
this._dx = dx;
}
private _x: number;
private _dx: number;
}
と等価です。
二項演算
次に二項演算です。Variable
だけでなく定数との演算も必要なので
-
number
&number
-
number
&Variable
-
Variable
&number
-
Variable
&Variable
の4種類のオーバーロードが必要です。
TypeScriptでのオーバーロードは、シグネチャをすべて列挙し、これらをすべて含む形でひとつだけ実装を書くという感じになります。不思議。
import * as Dual from './Dual';
export function add(x: number, y: number): number;
export function add(x: Dual.Variable, y: number): Dual.Variable;
export function add(x: number, y: Dual.Variable): Dual.Variable;
export function add(x: Dual.Variable, y: Dual.Variable): Dual.Variable;
export function add(
x: number | Dual.Variable,
y: number | Dual.Variable
): number | Dual.Variable {
if (typeof x === 'number') {
if (typeof y === 'number') {
return x + y;
}
return new Dual.Variable(x + y.x(), y.dx());
}
if (typeof y === 'number') {
return new Dual.Variable(x.x() + y, x.dx());
}
return new Dual.Variable(x.x() + y.x(), x.dx() + y.dx());
}
export function sub(x: number, y: number): number;
export function sub(x: Dual.Variable, y: number): Dual.Variable;
export function sub(x: number, y: Dual.Variable): Dual.Variable;
export function sub(x: Dual.Variable, y: Dual.Variable): Dual.Variable;
export function sub(
x: number | Dual.Variable,
y: number | Dual.Variable
): number | Dual.Variable {
if (typeof x === 'number') {
if (typeof y === 'number') {
return x - y;
}
return new Dual.Variable(x - y.x(), -y.dx());
}
if (typeof y === 'number') {
return new Dual.Variable(x.x() - y, x.dx());
}
return new Dual.Variable(x.x() - y.x(), y.dx() - y.dx());
}
export function mul(x: number, y: number): number;
export function mul(x: Dual.Variable, y: number): Dual.Variable;
export function mul(x: number, y: Dual.Variable): Dual.Variable;
export function mul(x: Dual.Variable, y: Dual.Variable): Dual.Variable;
export function mul(
x: number | Dual.Variable,
y: number | Dual.Variable
): number | Dual.Variable {
if (typeof x === 'number') {
if (typeof y === 'number') {
return x * y;
}
return new Dual.Variable(x * y.x(), x * y.dx());
}
if (typeof y === 'number') {
return new Dual.Variable(x.x() * y, x.dx() * y);
}
return new Dual.Variable(x.x() * y.x(), x.x() * y.dx() + y.x() * x.dx());
}
export function div(x: number, y: number): number;
export function div(x: Dual.Variable, y: number): Dual.Variable;
export function div(x: number, y: Dual.Variable): Dual.Variable;
export function div(x: Dual.Variable, y: Dual.Variable): Dual.Variable;
export function div(
x: number | Dual.Variable,
y: number | Dual.Variable
): number | Dual.Variable {
if (typeof x === 'number') {
if (typeof y === 'number') {
return x / y;
}
return new Dual.Variable(x / y.x(), -x / (y.dx() * y.dx()));
}
if (typeof y === 'number') {
return new Dual.Variable(x.x() / y, x.dx() / y);
}
return new Dual.Variable(x.x() / y.x(), (y.x() * x.dx() - x.x() * y.dx()) / (y.x() * y.x()));
}
単項演算
最後に単項演算です。オーバーロードはnumber
版とVariable
版の2種類です。
import * as Dual from './Dual';
export function negate(x: number): number;
export function negate(x: Dual.Variable): Dual.Variable;
export function negate(x: number | Dual.Variable): number | Dual.Variable {
if (typeof x === 'number') {
return -x;
}
return new Dual.Variable(-x.x(), -x.dx());
}
export function sqrt(x: number): number;
export function sqrt(x: Dual.Variable): Dual.Variable;
export function sqrt(x: number | Dual.Variable): number | Dual.Variable {
if (typeof x === 'number') {
return Math.sqrt(x);
}
return new Dual.Variable(Math.sqrt(x.x()), x.dx() / (2 * Math.sqrt(x.x())));
}
export function sin(x: number): number;
export function sin(x: Dual.Variable): Dual.Variable;
export function sin(x: number | Dual.Variable): number | Dual.Variable {
if (typeof x === 'number') {
return Math.sin(x);
}
return new Dual.Variable(Math.sin(x.x()), x.dx() * Math.cos(x.x()));
}
export function cos(x: number): number;
export function cos(x: Dual.Variable): Dual.Variable;
export function cos(x: number | Dual.Variable): number | Dual.Variable {
if (typeof x === 'number') {
return Math.cos(x);
}
return new Dual.Variable(Math.cos(x.x()), x.dx() * -Math.sin(x.x()));
}
export function tan(x: number): number;
export function tan(x: Dual.Variable): Dual.Variable;
export function tan(x: number | Dual.Variable): number | Dual.Variable {
if (typeof x === 'number') {
return Math.tan(x);
}
return new Dual.Variable(
Math.tan(x.x()),
x.dx() / (Math.cos(x.x()) * Math.cos(x.x()))
);
}
export function exp(x: number): number;
export function exp(x: Dual.Variable): Dual.Variable;
export function exp(x: number | Dual.Variable): number | Dual.Variable {
if (typeof x === 'number') {
return Math.exp(x);
}
return new Dual.Variable(Math.exp(x.x()), x.dx() * Math.exp(x.x()));
}
export function log(x: number): number;
export function log(x: Dual.Variable): Dual.Variable;
export function log(x: number | Dual.Variable): number | Dual.Variable {
if (typeof x === 'number') {
return Math.log(x);
}
return new Dual.Variable(Math.log(x.x()), x.dx() / x.x());
}
完成
これで完成です。npm run test
で、テストをパスするはずです。
おわりに
TypeScriptで自動微分ということで二重数を実装してみましたが、実際にはこの実装だけでは全然使い物になりません。当然、世の中には多変数関数の偏微分をしたいことが多いですし、高階微分したいこともありますが、ここでは対応していません。
また、本当は式情報を型の中に残して、遅延評価するなんてこともやりたかったのですが、時間の都合上叶わず。しかも、TypeScriptにおいては演算子のオーバーロードもできないし中置演算子も定義できないため、可読性という観点で仕上がりがいまいちでした…。
うーん、この題材をTypeScriptでとりあげるにはやや必然性がなかったか。まぁTypeScriptの入門はReactばっかでどうかなと思っていたので、ちょうどよい感じにはなったんじゃないかな!
そーす