8
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

JavaScript で RNN(リカレントニューラルネットワーク)

Last updated at Posted at 2016-12-06

JavaScriptでリカレントニューラルネットワーク(Recurrent Neural Networks: RNN)を実装してみました。
完成したコードはGitHubのこちらのリポジトリにまとめてあります。 また、数式はこちらにまとめましたので、理論部分についてはそちらを参考にしてください。

まずは核となる rnn.js から。

const math = require('./math');

class RNN {
  constructor(nIn, nHidden, nOut, truncatedTime = 3, learningRate = 0.1, activation = math.fn.tanh, rng = Math.random) {
    this.nIn = nIn;
    this.nHidden = nHidden;
    this.nOut = nOut;
    this.truncatedTime = truncatedTime;
    this.learningRate = learningRate;
    this.activation = activation;

    // this._activationOutput = (nOut === 1) ? math.fn.sigmoid : math.fn.softmax;

    this.U = math.array.uniform(-Math.sqrt(1/nIn), Math.sqrt(1/nIn), rng, [nHidden, nIn]);  // input -> hidden
    this.V = math.array.uniform(-Math.sqrt(1/nHidden), Math.sqrt(1/nHidden), rng, [nOut, nHidden]);  // hidden -> output
    this.W = math.array.uniform(-Math.sqrt(1/nHidden), Math.sqrt(1/nHidden), rng, [nHidden, nHidden]);  // hidden -> hidden

    this.b = math.array.zeros(nHidden);  // hidden bias
    this.c = math.array.zeros(nOut);  // output bias
  }

  // x: number[][]  ( number[time][index] )
  forwardProp(x) {
    let timeLength = x.length;

    let s = math.array.zeros(timeLength, this.nHidden);
    let u = math.array.zeros(timeLength, this.nHidden);
    let y = math.array.zeros(timeLength, this.nOut);
    let v = math.array.zeros(timeLength, this.nOut);

    for (let t = 0; t < timeLength; t++) {
      let _st = (t === 0) ? math.array.zeros(this.nHidden) : s[t - 1];
      u[t] = math.add(math.add(math.dot(this.U, x[t]), math.dot(this.W, _st)), this.b);
      s[t] = this.activation(u[t]);

      v[t] = math.add(math.dot(this.V, s[t]), this.c)
      // y[t] = this._activationOutput(this.v[t]);
      y[t] = math.fn.linear(v[t]);
    }

    return {
      s: s,
      u: u,
      y: y,
      v: v
    };
  }

  backProp(x, label) {
    let dU = math.array.zeros(this.nHidden, this.nIn);
    let dV = math.array.zeros(this.nOut, this.nHidden);
    let dW = math.array.zeros(this.nHidden, this.nHidden);
    let db = math.array.zeros(this.nHidden);
    let dc = math.array.zeros(this.nOut);

    let timeLength = x.length;
    let units = this.forwardProp(x);
    let s = units.s;
    let u = units.u;
    let y = units.y;
    let v = units.v;

    // let eo = math.mul(math.sub(o, label), this._activationOutput.grad(this.v));
    let eo = math.mul(math.sub(y, label), math.fn.linear.grad(v));
    let eh = math.array.zeros(timeLength, this.nHidden);

    for (let t = timeLength - 1; t >= 0; t--) {
      dV = math.add(dV, math.outer(eo[t], s[t]));
      dc = math.add(dc, eo[t]);
      eh[t] = math.mul(math.dot(eo[t], this.V), this.activation.grad(u[t]));

      for (let z = 0; z < this.truncatedTime; z++) {
        if (t - z < 0) {
          break;
        }

        dU = math.add(dU, math.outer(eh[t - z], x[t - z]));
        db = math.add(db, eh[t - z]);

        if (t - z - 1 >= 0) {
          dW = math.add(dW, math.outer(eh[t - z], s[t - z - 1]));
          eh[t - z - 1] = math.mul(math.dot(eh[t - z], this.W), this.activation.grad(u[t - z - 1]));
        }
      }
    }

    return {
      grad: {
        U: dU,
        V: dV,
        W: dW,
        b: db,
        c: dc
      }
    };
  }

  sgd(x, label, learningRate) {
    learningRate = learningRate || this.learningRate;
    let grad = this.backProp(x, label).grad;

    this.U = math.sub(this.U, math.mul(learningRate, grad.U));
    this.V = math.sub(this.V, math.mul(learningRate, grad.V));
    this.W = math.sub(this.W, math.mul(learningRate, grad.W));
    this.b = math.sub(this.b, math.mul(learningRate, grad.b));
    this.c = math.sub(this.c, math.mul(learningRate, grad.c));
  }

  predict(x) {
    let units = this.forwardProp(x);
    return units.y;
  }
}


module.exports = RNN;

さて、ここで肝心なのが、冒頭にある math です。これは Python で言うところの numpy のような挙動を目指すべく、いくつか線形代数計算で必要となるところの実装をまとめたものです。リポジトリ内の math ディレクトリに色々メソッドを書いています(ただし、まだまだWIPなところも多いです)。 似たようなライブラリには math.js がありますが、Matrix Object が個人的に扱いづらく、あくまでもPure Arrayで計算処理を行いたかったので、自分で実装しています。 math.array.zerosmath.dot, math.outer など、 numpyっぽい書き方で線形代数演算が行えるようになるので、 RNNクラス の各メソッドも、割りと数式通りにスッキリ書くことができています。

また、出力層の活性化関数はsoftmax/sigmoid部分をコメントアウトして、単純な線形活性を用いていますが、これは今回予測したいタスク(後述)に合わせる形となっています。

さて、今回予測するのは、sin波です。こちらの記事でもありますが、手っ取り早くRNNの予測を試すにはうってつけのタスクです。0...t の sin波 が与えられたときに、t+1 のsin波を予測します。 ただし、単純なsin波ではなく、-1 ~ 1 の一様分布に係数0.1をかけたノイズを波に足しています。これで予測を行ったのが下記のコード。

const print = require('./utils').print;
const math = require('./math');
const seedrandom = require('seedrandom');
const RNN = require('./rnn');

let rng = seedrandom(1234);


function main() {

  const TRAIN_NUM = 30;  // time sequence
  const TEST_NUM = 10;

  const N_IN = 1;
  const N_HIDDEN = 4;
  const N_OUT = 1;
  const TRUNCATED_TIME = 4;
  const LEARNING_RATE = 0.01;
  const EPOCHS = 100;

  let classifier = new RNN(N_IN, N_HIDDEN, N_OUT, TRUNCATED_TIME, LEARNING_RATE, math.fn.tanh, rng);

  for (let epoch = 0; epoch < EPOCHS; epoch++) {
    if (epoch !== 0 && epoch % 10 === 0) {
      print(`epoch: ${epoch}`);
    }
    let _data = loadData(TRAIN_NUM);

    classifier.sgd(_data.x, _data.y);
  }

  let testX = loadData(TEST_NUM).x;
  let output = null;
  for (let i = 0; i < 50; i++) {
    output = classifier.predict(testX);
    testX.push(output[output.length - 1]);
  }

  print('-----');
  for (let i = TEST_NUM; i < testX.length - 1; i++) {
    print(output[i][0]);
  }
  print('-----');

  // for (let data of output) {
  //   print(data[0])
  // }
}

function loadData(dataNum) {
  let x = [];  // sin wave + noise [0, t]
  let y = [];  // t + 1
  const TIME_STEP = 0.1;

  let noise = () => {
    return 0.1 * math.random.uniform(-1, 1, rng);
  }

  for (let i = 0; i < dataNum + 1; i++) {
    let _t = i * TIME_STEP;
    let _sin = Math.sin(_t * Math.PI);
    x[i] = [_sin + noise()];

    if (i !== 0) {
      y[i - 1] = x[i];
    }
  }
  x.pop();

  return {
    x: x,
    y: y
  };
}

main();

こちらは特段説明が必要なところはないかと思いますが、main の最後の部分の出力を可視化してみると、きちんと(ノイズがまあまあ取り除かれた)sin波が描かれるのが分かるかと思います。

今回はとても単純なRNNをJavaScriptで実装しました。

$ npm install
$ node main.js

と実行すると結果が得られますが、せっかくJSで実装しているので、ブラウザで経過が見られるように変えていこうと思っています(他の手法も実装していきたい)。

8
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
8
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?