1. yusugomori@github

    Posted

    yusugomori@github
Changes in title
+JavaScript でリカレントニューラルネットワーク
Changes in tags
Changes in body
Source | HTML | Preview
@@ -0,0 +1,217 @@
+JavaScriptでリカレントニューラルネットワーク(Recurrent Neural Networks: RNN)を実装してみました。
+完成したコードはGitHubの[こちら](https://github.com/yusugomori/rnn.js)のリポジトリにまとめてあります。 また、数式は[こちら](https://micin.jp/feed/developer/articles/rnn000)にまとめましたので、理論部分についてはそちらを参考にしてください。
+
+まずは核となる rnn.js から。
+
+```js
+const math = require('./math');
+
+class RNN {
+ constructor(nIn, nHidden, nOut, truncatedTime = 3, learningRate = 0.1, activation = math.fn.tanh, rng = Math.random) {
+ this.nIn = nIn;
+ this.nHidden = nHidden;
+ this.nOut = nOut;
+ this.truncatedTime = truncatedTime;
+ this.learningRate = learningRate;
+ this.activation = activation;
+
+ // this._activationOutput = (nOut === 1) ? math.fn.sigmoid : math.fn.softmax;
+
+ this.U = math.array.uniform(-Math.sqrt(1/nIn), Math.sqrt(1/nIn), rng, [nHidden, nIn]); // input -> hidden
+ this.V = math.array.uniform(-Math.sqrt(1/nHidden), Math.sqrt(1/nHidden), rng, [nOut, nHidden]); // hidden -> output
+ this.W = math.array.uniform(-Math.sqrt(1/nHidden), Math.sqrt(1/nHidden), rng, [nHidden, nHidden]); // hidden -> hidden
+
+ this.b = math.array.zeros(nHidden); // hidden bias
+ this.c = math.array.zeros(nOut); // output bias
+ }
+
+ // x: number[][] ( number[time][index] )
+ forwardProp(x) {
+ let timeLength = x.length;
+
+ let s = math.array.zeros(timeLength, this.nHidden);
+ let u = math.array.zeros(timeLength, this.nHidden);
+ let y = math.array.zeros(timeLength, this.nOut);
+ let v = math.array.zeros(timeLength, this.nOut);
+
+ for (let t = 0; t < timeLength; t++) {
+ let _st = (t === 0) ? math.array.zeros(this.nHidden) : s[t - 1];
+ u[t] = math.add(math.add(math.dot(this.U, x[t]), math.dot(this.W, _st)), this.b);
+ s[t] = this.activation(u[t]);
+
+ v[t] = math.add(math.dot(this.V, s[t]), this.c)
+ // y[t] = this._activationOutput(this.v[t]);
+ y[t] = math.fn.linear(v[t]);
+ }
+
+ return {
+ s: s,
+ u: u,
+ y: y,
+ v: v
+ };
+ }
+
+ backProp(x, label) {
+ let dU = math.array.zeros(this.nHidden, this.nIn);
+ let dV = math.array.zeros(this.nOut, this.nHidden);
+ let dW = math.array.zeros(this.nHidden, this.nHidden);
+ let db = math.array.zeros(this.nHidden);
+ let dc = math.array.zeros(this.nOut);
+
+ let timeLength = x.length;
+ let units = this.forwardProp(x);
+ let s = units.s;
+ let u = units.u;
+ let y = units.y;
+ let v = units.v;
+
+ // let eo = math.mul(math.sub(o, label), this._activationOutput.grad(this.v));
+ let eo = math.mul(math.sub(y, label), math.fn.linear.grad(v));
+ let eh = math.array.zeros(timeLength, this.nHidden);
+
+ for (let t = timeLength - 1; t >= 0; t--) {
+ dV = math.add(dV, math.outer(eo[t], s[t]));
+ dc = math.add(dc, eo[t]);
+ eh[t] = math.mul(math.dot(eo[t], this.V), this.activation.grad(u[t]));
+
+ for (let z = 0; z < this.truncatedTime; z++) {
+ if (t - z < 0) {
+ break;
+ }
+
+ dU = math.add(dU, math.outer(eh[t - z], x[t - z]));
+ db = math.add(db, eh[t - z]);
+
+ if (t - z - 1 >= 0) {
+ dW = math.add(dW, math.outer(eh[t - z], s[t - z - 1]));
+ eh[t - z - 1] = math.mul(math.dot(eh[t - z], this.W), this.activation.grad(u[t - z - 1]));
+ }
+ }
+ }
+
+ return {
+ grad: {
+ U: dU,
+ V: dV,
+ W: dW,
+ b: db,
+ c: dc
+ }
+ };
+ }
+
+ sgd(x, label, learningRate) {
+ learningRate = learningRate || this.learningRate;
+ let grad = this.backProp(x, label).grad;
+
+ this.U = math.sub(this.U, math.mul(learningRate, grad.U));
+ this.V = math.sub(this.V, math.mul(learningRate, grad.V));
+ this.W = math.sub(this.W, math.mul(learningRate, grad.W));
+ }
+
+ predict(x) {
+ let units = this.forwardProp(x);
+ return units.y;
+ }
+}
+
+
+module.exports = RNN;
+```
+
+さて、ここで肝心なのが、冒頭にある `math` です。これは Python で言うところの numpy のような挙動を目指すべく、いくつか線形代数計算で必要となるところの実装をまとめたものです。リポジトリ内の math ディレクトリに色々メソッドを書いています(ただし、まだまだWIPなところも多いです)。 似たようなライブラリには math.js がありますが、Matrix Object が個人的に扱いづらく、あくまでもPure Arrayで計算処理を行いたかったので、自分で実装しています。 `math.array.zeros` や `math.dot`, `math.outer` など、 numpyっぽい書き方で線形代数演算が行えるようになるので、 RNNクラス の各メソッドも、割りと数式通りにスッキリ書くことができています。
+
+また、出力層の活性化関数はsoftmax/sigmoid部分をコメントアウトして、単純な線形活性を用いていますが、これは今回予測したいタスク(後述)に合わせる形となっています。
+
+さて、今回予測するのは、sin波です。[こちらの記事](http://qiita.com/yuyakato/items/ab38064ca215e8750865)でもありますが、手っ取り早くRNNの予測を試すにはうってつけのタスクです。`0...t` の sin波 が与えられたときに、`t+1` のsin波を予測します。 ただし、単純なsin波ではなく、-1 ~ 1 の一様分布に係数0.1をかけたノイズを波に足しています。これで予測を行ったのが下記のコード。
+
+```js
+const print = require('./utils').print;
+const math = require('./math');
+const seedrandom = require('seedrandom');
+const RNN = require('./rnn');
+
+let rng = seedrandom(1234);
+
+
+function main() {
+
+ const TRAIN_NUM = 30; // time sequence
+ const TEST_NUM = 10;
+
+ const N_IN = 1;
+ const N_HIDDEN = 4;
+ const N_OUT = 1;
+ const TRUNCATED_TIME = 3;
+ const LEARNING_RATE = 0.05;
+ const EPOCHS = 50;
+
+ let classifier = new RNN(N_IN, N_HIDDEN, N_OUT, TRUNCATED_TIME, LEARNING_RATE, math.fn.tanh, rng);
+
+ for (let epoch = 0; epoch < EPOCHS; epoch++) {
+ if (epoch !== 0 && epoch % 10 === 0) {
+ print(`epoch: ${epoch}`);
+ }
+ let _data = loadData(TRAIN_NUM);
+
+ classifier.sgd(_data.x, _data.y);
+ }
+
+ let testX = loadData(TEST_NUM).x;
+ let output = null;
+ for (let i = 0; i < 50; i++) {
+ output = classifier.predict(testX);
+ testX.push(output[output.length - 1]);
+ }
+
+ print('-----');
+ for (let i = TEST_NUM; i < testX.length - 1; i++) {
+ print(output[i][0]);
+ }
+ print('-----');
+
+ // for (let data of output) {
+ // print(data[0])
+ // }
+}
+
+function loadData(dataNum) {
+ let x = []; // sin wave + noise [0, t]
+ let y = []; // t + 1
+ const TIME_STEP = 0.1;
+
+ let noise = () => {
+ return 0.1 * math.random.uniform(-1, 1, rng);
+ }
+
+ for (let i = 0; i < dataNum + 1; i++) {
+ let _t = i * TIME_STEP;
+ let _sin = Math.sin(_t * Math.PI);
+ x[i] = [_sin + noise()];
+
+ if (i !== 0) {
+ y[i - 1] = x[i];
+ }
+ }
+ x.pop();
+
+ return {
+ x: x,
+ y: y
+ };
+}
+
+main();
+```
+
+こちらは特段説明が必要なところはないかと思いますが、`main` の最後の部分の出力を可視化してみると、きちんと(ノイズがまあまあ取り除かれた)sin波が描かれるのが分かるかと思います。
+
+今回はとても単純なRNNをJavaScriptで実装しました。
+
+```
+$ npm install
+$ node main.js
+```
+
+と実行すると結果が得られますが、せっかくJSで実装しているので、ブラウザで経過が見られるように変えていこうと思っています(他の手法も実装していきたい)。