Qiita Teams that are logged in
You are not logged in to any team

Community
Service
Qiita JobsQiita ZineQiita Blog
7
Help us understand the problem. What are the problem?

More than 3 years have passed since last update.

@yusugomori@github

# JavaScript で RNN（リカレントニューラルネットワーク）

JavaScriptでリカレントニューラルネットワーク（Recurrent Neural Networks: RNN）を実装してみました。

まずは核となる rnn.js から。

``````const math = require('./math');

class RNN {
constructor(nIn, nHidden, nOut, truncatedTime = 3, learningRate = 0.1, activation = math.fn.tanh, rng = Math.random) {
this.nIn = nIn;
this.nHidden = nHidden;
this.nOut = nOut;
this.truncatedTime = truncatedTime;
this.learningRate = learningRate;
this.activation = activation;

// this._activationOutput = (nOut === 1) ? math.fn.sigmoid : math.fn.softmax;

this.U = math.array.uniform(-Math.sqrt(1/nIn), Math.sqrt(1/nIn), rng, [nHidden, nIn]);  // input -> hidden
this.V = math.array.uniform(-Math.sqrt(1/nHidden), Math.sqrt(1/nHidden), rng, [nOut, nHidden]);  // hidden -> output
this.W = math.array.uniform(-Math.sqrt(1/nHidden), Math.sqrt(1/nHidden), rng, [nHidden, nHidden]);  // hidden -> hidden

this.b = math.array.zeros(nHidden);  // hidden bias
this.c = math.array.zeros(nOut);  // output bias
}

// x: number[][]  ( number[time][index] )
forwardProp(x) {
let timeLength = x.length;

let s = math.array.zeros(timeLength, this.nHidden);
let u = math.array.zeros(timeLength, this.nHidden);
let y = math.array.zeros(timeLength, this.nOut);
let v = math.array.zeros(timeLength, this.nOut);

for (let t = 0; t < timeLength; t++) {
let _st = (t === 0) ? math.array.zeros(this.nHidden) : s[t - 1];
s[t] = this.activation(u[t]);

// y[t] = this._activationOutput(this.v[t]);
y[t] = math.fn.linear(v[t]);
}

return {
s: s,
u: u,
y: y,
v: v
};
}

backProp(x, label) {
let dU = math.array.zeros(this.nHidden, this.nIn);
let dV = math.array.zeros(this.nOut, this.nHidden);
let dW = math.array.zeros(this.nHidden, this.nHidden);
let db = math.array.zeros(this.nHidden);
let dc = math.array.zeros(this.nOut);

let timeLength = x.length;
let units = this.forwardProp(x);
let s = units.s;
let u = units.u;
let y = units.y;
let v = units.v;

// let eo = math.mul(math.sub(o, label), this._activationOutput.grad(this.v));
let eo = math.mul(math.sub(y, label), math.fn.linear.grad(v));
let eh = math.array.zeros(timeLength, this.nHidden);

for (let t = timeLength - 1; t >= 0; t--) {

for (let z = 0; z < this.truncatedTime; z++) {
if (t - z < 0) {
break;
}

dU = math.add(dU, math.outer(eh[t - z], x[t - z]));
db = math.add(db, eh[t - z]);

if (t - z - 1 >= 0) {
dW = math.add(dW, math.outer(eh[t - z], s[t - z - 1]));
eh[t - z - 1] = math.mul(math.dot(eh[t - z], this.W), this.activation.grad(u[t - z - 1]));
}
}
}

return {
U: dU,
V: dV,
W: dW,
b: db,
c: dc
}
};
}

sgd(x, label, learningRate) {
learningRate = learningRate || this.learningRate;

}

predict(x) {
let units = this.forwardProp(x);
return units.y;
}
}

module.exports = RNN;
``````

さて、ここで肝心なのが、冒頭にある `math` です。これは Python で言うところの numpy のような挙動を目指すべく、いくつか線形代数計算で必要となるところの実装をまとめたものです。リポジトリ内の math ディレクトリに色々メソッドを書いています（ただし、まだまだWIPなところも多いです）。 似たようなライブラリには math.js がありますが、Matrix Object が個人的に扱いづらく、あくまでもPure Arrayで計算処理を行いたかったので、自分で実装しています。 `math.array.zeros``math.dot`, `math.outer` など、 numpyっぽい書き方で線形代数演算が行えるようになるので、 RNNクラス の各メソッドも、割りと数式通りにスッキリ書くことができています。

また、出力層の活性化関数はsoftmax/sigmoid部分をコメントアウトして、単純な線形活性を用いていますが、これは今回予測したいタスク（後述）に合わせる形となっています。

さて、今回予測するのは、sin波です。こちらの記事でもありますが、手っ取り早くRNNの予測を試すにはうってつけのタスクです。`0...t` の sin波 が与えられたときに、`t+1` のsin波を予測します。 ただし、単純なsin波ではなく、-1 ~ 1 の一様分布に係数0.1をかけたノイズを波に足しています。これで予測を行ったのが下記のコード。

``````const print = require('./utils').print;
const math = require('./math');
const seedrandom = require('seedrandom');
const RNN = require('./rnn');

let rng = seedrandom(1234);

function main() {

const TRAIN_NUM = 30;  // time sequence
const TEST_NUM = 10;

const N_IN = 1;
const N_HIDDEN = 4;
const N_OUT = 1;
const TRUNCATED_TIME = 4;
const LEARNING_RATE = 0.01;
const EPOCHS = 100;

let classifier = new RNN(N_IN, N_HIDDEN, N_OUT, TRUNCATED_TIME, LEARNING_RATE, math.fn.tanh, rng);

for (let epoch = 0; epoch < EPOCHS; epoch++) {
if (epoch !== 0 && epoch % 10 === 0) {
print(`epoch: \${epoch}`);
}

classifier.sgd(_data.x, _data.y);
}

let output = null;
for (let i = 0; i < 50; i++) {
output = classifier.predict(testX);
testX.push(output[output.length - 1]);
}

print('-----');
for (let i = TEST_NUM; i < testX.length - 1; i++) {
print(output[i][0]);
}
print('-----');

// for (let data of output) {
//   print(data[0])
// }
}

let x = [];  // sin wave + noise [0, t]
let y = [];  // t + 1
const TIME_STEP = 0.1;

let noise = () => {
return 0.1 * math.random.uniform(-1, 1, rng);
}

for (let i = 0; i < dataNum + 1; i++) {
let _t = i * TIME_STEP;
let _sin = Math.sin(_t * Math.PI);
x[i] = [_sin + noise()];

if (i !== 0) {
y[i - 1] = x[i];
}
}
x.pop();

return {
x: x,
y: y
};
}

main();
``````

こちらは特段説明が必要なところはないかと思いますが、`main` の最後の部分の出力を可視化してみると、きちんと（ノイズがまあまあ取り除かれた）sin波が描かれるのが分かるかと思います。

``````\$ npm install
\$ node main.js
``````

と実行すると結果が得られますが、せっかくJSで実装しているので、ブラウザで経過が見られるように変えていこうと思っています（他の手法も実装していきたい）。

Why not register and get more from Qiita?
1. We will deliver articles that match you
By following users and tags, you can catch up information on technical fields that you are interested in as a whole
2. you can read useful information later efficiently
By "stocking" the articles you like, you can search right away
7
Help us understand the problem. What are the problem?