1. yusugomori@github

    Posted

    yusugomori@github
Changes in title
+JavaScript で LSTM (Long Short-Term Memory)
Changes in tags
Changes in body
Source | HTML | Preview
@@ -0,0 +1,355 @@
+[RNN](http://qiita.com/yusugomori@github/items/2199bc31e8b240d26907)に引き続き、LSTMをJavaScriptで実装してみました。今回もsin波の予測を行っています。RNNを含むコード全体は[こちら](https://github.com/yusugomori/DeepLearningJS)のリポジトリをご覧ください。また、理論・数式は[こちら](https://micin.jp/feed/developer/articles/lstm00)にまとめていますので、参考にしてみてください。
+
+```js:lstm.js
+const math = require('./math');
+
+
+class LSTM {
+ constructor(nIn, nHidden, nOut, learningRate, activation = math.fn.tanh, rng = Math.random) {
+ this.nIn = nIn;
+ this.nHidden = nHidden;
+ this.nOut = nOut;
+ this.learningRate = learningRate;
+ this.activation = activation;
+
+ this.Wc = math.array.uniform(-Math.sqrt(1/nIn), Math.sqrt(1/nIn), rng, [nHidden, nIn]);
+ this.Wi = math.array.uniform(-Math.sqrt(1/nIn), Math.sqrt(1/nIn), rng, [nHidden, nIn]);
+ this.Wf = math.array.uniform(-Math.sqrt(1/nIn), Math.sqrt(1/nIn), rng, [nHidden, nIn]);
+ this.Wo = math.array.uniform(-Math.sqrt(1/nIn), Math.sqrt(1/nIn), rng, [nHidden, nIn]);
+
+ this.Uc = math.array.uniform(-Math.sqrt(1/nHidden), Math.sqrt(1/nHidden), rng, [nHidden, nHidden]);
+ this.Ui = math.array.uniform(-Math.sqrt(1/nHidden), Math.sqrt(1/nHidden), rng, [nHidden, nHidden]);
+ this.Uf = math.array.uniform(-Math.sqrt(1/nHidden), Math.sqrt(1/nHidden), rng, [nHidden, nHidden]);
+ this.Uo = math.array.uniform(-Math.sqrt(1/nHidden), Math.sqrt(1/nHidden), rng, [nHidden, nHidden]);
+
+ this.bc = math.array.zeros(nHidden);
+ this.bi = math.array.zeros(nHidden);
+ this.bf = math.array.zeros(nHidden);
+ this.bo = math.array.zeros(nHidden);
+
+ this.V = math.array.uniform(-Math.sqrt(1/nHidden), Math.sqrt(1/nHidden), rng, [nOut, nHidden]);
+ this.b = math.array.zeros(nOut);
+ }
+
+ forwardProp(x) {
+ let timeLength = x.length;
+
+ let pre = {
+ a: math.array.zeros(timeLength, this.nHidden),
+ i: math.array.zeros(timeLength, this.nHidden),
+ f: math.array.zeros(timeLength, this.nHidden),
+ o: math.array.zeros(timeLength, this.nHidden),
+ y: math.array.zeros(timeLength, this.nOut)
+ };
+
+ let a = math.array.zeros(timeLength, this.nHidden);
+ let i = math.array.zeros(timeLength, this.nHidden);
+ let f = math.array.zeros(timeLength, this.nHidden);
+ let o = math.array.zeros(timeLength, this.nHidden);
+
+ let c = math.array.zeros(timeLength, this.nHidden);
+ let h = math.array.zeros(timeLength, this.nHidden);
+
+ let y = math.array.zeros(timeLength, this.nOut);
+
+ for (let t = 0; t < timeLength; t++) {
+ let _ht = (t === 0) ? math.array.zeros(this.nHidden) : h[t - 1];
+ pre.a[t] = math.add(math.add(math.dot(this.Wc, x[t]), math.dot(this.Uc, _ht)), this.bc);
+ pre.i[t] = math.add(math.add(math.dot(this.Wi, x[t]), math.dot(this.Ui, _ht)), this.bi);
+ pre.f[t] = math.add(math.add(math.dot(this.Wf, x[t]), math.dot(this.Uf, _ht)), this.bf);
+ pre.o[t] = math.add(math.add(math.dot(this.Wo, x[t]), math.dot(this.Uo, _ht)), this.bo);
+
+ a[t] = this.activation(pre.a[t]);
+ i[t] = math.fn.sigmoid(pre.i[t]);
+ f[t] = math.fn.sigmoid(pre.f[t]);
+ o[t] = math.fn.sigmoid(pre.o[t]);
+
+ if (t !== 0) {
+ c[t] = math.add(math.mul(i[t], a[t]), math.mul(f[t], c[t - 1]));
+ }
+
+ h[t] = math.mul(o[t], this.activation(c[t]));
+
+ pre.y[t] = math.add(math.dot(this.V, h[t]), this.b);
+
+ // y[t] = math.fn.softmax(pre.y[t]);
+ y[t] = math.fn.linear(pre.y[t]);
+ }
+
+ return {
+ a: a,
+ i: i,
+ f: f,
+ o: o,
+ c: c,
+ h: h,
+ y: y,
+ pre: pre
+ };
+ }
+
+ backProp(x, label) {
+ let dV = math.array.zeros(this.nOut, this.nHidden);
+ let db = math.array.zeros(this.nOut);
+
+ let timeLength = x.length;
+ let units = this.forwardProp(x);
+ let a = units.a;
+ let i = units.i;
+ let f = units.f;
+ let o = units.o;
+ let c = units.c;
+ let h = units.h;
+ let y = units.y;
+ let pre = units.pre;
+
+ let W = this._packW();
+ let dW = math.array.zeros(this.nHidden * 4, this.nIn + this.nHidden + 1);
+ let z = math.array.zeros(timeLength, this.nIn + this.nHidden + 1);
+ let s = math.array.zeros(timeLength, this.nHidden * 4);
+
+ let delta = {
+ o: math.sub(y, label),
+ h: math.mul(math.sub(y, label), math.fn.linear.grad(pre.y))
+ // h: math.mul(math.sub(y, label), math.fn.softmax.grad(pre.y))
+ };
+
+ let e = {
+ a: math.array.zeros(timeLength, this.nHidden), // input error of LSTM block
+ i: math.array.zeros(timeLength, this.nHidden), // input gate
+ f: math.array.zeros(timeLength, this.nHidden), // forget gate
+ o: math.array.zeros(timeLength, this.nHidden), // output gate
+ c: math.array.zeros(timeLength, this.nHidden), // cell
+ h: math.array.zeros(timeLength, this.nHidden), // output error of LSTM block
+ pre: {
+ a: math.array.zeros(timeLength, this.nHidden),
+ i: math.array.zeros(timeLength, this.nHidden),
+ f: math.array.zeros(timeLength, this.nHidden),
+ o: math.array.zeros(timeLength, this.nHidden),
+ },
+
+ s: math.array.zeros(timeLength, this.nHidden * 4), // pack errors above
+ z: math.array.zeros(timeLength, this.nIn + this.nHidden + 1) // error for z: [x(t), h(t-1), 1] below
+ };
+
+ for (let t = timeLength - 1; t >= 0; t--) {
+ dV = math.add(dV, math.outer(delta.o[t], h[t]));
+ db = math.add(db, delta.o[t]);
+
+ e.h[t] = math.dot(delta.h[t], this.V);
+
+ e.o[t] = math.mul(e.h[t], this.activation(c[t]));
+ let _ec = math.mul(math.mul(e.h[t], o[t]), this.activation.grad(c[t]));
+ e.c[t] = math.add(e.c[t], _ec);
+
+ if (t !== 0) {
+ e.c[t - 1] = math.mul(e.c[t], f[t]);
+ e.f[t] = math.mul(e.c[t], c[t - 1]);
+ }
+ e.i[t] = math.mul(e.c[t], a[t]);
+
+ e.a[t] = math.mul(e.c[t], i[t]);
+
+ e.pre.a[t] = math.mul(e.a[t], this.activation.grad(pre.a[t]));
+ e.pre.i[t] = math.mul(e.i[t], math.fn.sigmoid.grad(pre.i[t]));
+ e.pre.f[t] = math.mul(e.f[t], math.fn.sigmoid.grad(pre.f[t]));
+ e.pre.o[t] = math.mul(e.o[t], math.fn.sigmoid.grad(pre.o[t]));
+
+ let _ht = (t === 0) ? math.array.zeros(this.nHidden) : h[t - 1];
+ e.s[t] = e.pre.a[t].concat(e.pre.i[t], e.pre.f[t], e.pre.o[t]);
+ z[t] = x[t].concat(_ht, [1]);
+
+ s[t] = math.dot(z[t], math.T(W));
+ e.z[t] = math.dot(e.s[t], W);
+
+ dW = math.add(dW, math.outer(e.s[t], z[t]));
+ }
+
+ return {
+ grad: {
+ W: dW,
+ V: dV,
+ b: db
+ }
+ };
+ }
+
+ sgd(x, label, learningRate) {
+ learningRate = learningRate || this.learningRate;
+ let grad = this.backProp(x, label).grad;
+ let dW = grad.W;
+ let dV = grad.V;
+ let db = grad.b;
+
+ let index = {
+ row: {
+ c: this.nHidden,
+ i: this.nHidden * 2,
+ f: this.nHidden * 3,
+ o: this.nHidden * 4
+ },
+ col: {
+ W: this.nIn,
+ U: this.nIn + this.nHidden,
+ b: this.nIn + this.nHidden + 1
+ }
+ };
+
+ let dWc = math.subset(dW, [[ 0, index.row.c], [0, index.col.W]]);
+ let dWi = math.subset(dW, [[index.row.c, index.row.i], [0, index.col.W]]);
+ let dWf = math.subset(dW, [[index.row.i, index.row.f], [0, index.col.W]]);
+ let dWo = math.subset(dW, [[index.row.f, index.row.o], [0, index.col.W]]);
+
+ let dUc = math.subset(dW, [[ 0, index.row.c], [index.col.W, index.col.U]]);
+ let dUi = math.subset(dW, [[index.row.c, index.row.i], [index.col.W, index.col.U]]);
+ let dUf = math.subset(dW, [[index.row.i, index.row.f], [index.col.W, index.col.U]]);
+ let dUo = math.subset(dW, [[index.row.f, index.row.o], [index.col.W, index.col.U]]);
+
+ let dbc = math.flatten(math.subset(dW, [[ 0, index.row.c], [index.col.U, index.col.b]]));
+ let dbi = math.flatten(math.subset(dW, [[index.row.c, index.row.i], [index.col.U, index.col.b]]));
+ let dbf = math.flatten(math.subset(dW, [[index.row.i, index.row.f], [index.col.U, index.col.b]]));
+ let dbo = math.flatten(math.subset(dW, [[index.row.f, index.row.o], [index.col.U, index.col.b]]));
+
+ this.Wc = math.sub(this.Wc, math.mul(learningRate, dWc));
+ this.Wi = math.sub(this.Wi, math.mul(learningRate, dWi));
+ this.Wf = math.sub(this.Wf, math.mul(learningRate, dWf));
+ this.Wo = math.sub(this.Wo, math.mul(learningRate, dWo));
+
+ this.Uc = math.sub(this.Uc, math.mul(learningRate, dUc));
+ this.Ui = math.sub(this.Ui, math.mul(learningRate, dUi));
+ this.Uf = math.sub(this.Uf, math.mul(learningRate, dUf));
+ this.Uo = math.sub(this.Uo, math.mul(learningRate, dUo));
+
+ this.bc = math.sub(this.bc, math.mul(learningRate, dbc));
+ this.bi = math.sub(this.bi, math.mul(learningRate, dbi));
+ this.bf = math.sub(this.bf, math.mul(learningRate, dbf));
+ this.bo = math.sub(this.bo, math.mul(learningRate, dbo));
+
+ this.V = math.sub(this.V, math.mul(learningRate / x.length, dV));
+ this.b = math.sub(this.b, math.mul(learningRate / x.length, db));
+ }
+
+ _packW() {
+ let W = math.array.zeros(this.nHidden * 4, this.nIn + this.nHidden + 1);
+
+ let subset = (index, w, u, b) => {
+ let i = index * this.nHidden;
+ for (let j = 0; j < this.nHidden; j++) {
+ W[i + j] = w[j].concat(u[j], b[j]);
+ }
+ };
+
+ for (let index of [0, 1, 2, 3]) {
+ switch (index) {
+ case 0:
+ subset(index, this.Wc, this.Uc, this.bc);
+ break;
+ case 1:
+ subset(index, this.Wi, this.Ui, this.bi);
+ break;
+ case 2:
+ subset(index, this.Wf, this.Uf, this.bf);
+ break;
+ case 3:
+ subset(index, this.Wo, this.Uo, this.bo);
+ break;
+ default:
+ break;
+ }
+ }
+ return W;
+ }
+
+ predict(x) {
+ let units = this.forwardProp(x);
+ return units.y;
+ }
+}
+
+module.exports = LSTM;
+```
+
+最もシンプルなLSTMの実装ということで、LSTMブロックに続く出力層部分も同時に実装しているなど、少し簡略的に書いているところもありますが、全体の流れとしては数式に沿って書いてあります。
+
+この`LSTM`クラスを利用して、sin波の予測を行うのがこちら。
+
+```js:main.js
+const seedrandom = require('seedrandom');
+const print = require('./utils').print;
+const math = require('./math');
+const LSTM = require('./lstm');
+
+let rng = seedrandom(1234);
+
+function main() {
+ const TRAIN_NUM = 30; // time sequence
+ const TEST_NUM = 10;
+
+ const N_IN = 1;
+ const N_HIDDEN = 8;
+ const N_OUT = 1;
+ const LEARNING_RATE = 0.1;
+ const EPOCHS = 200;
+
+ let classifier = new LSTM(N_IN, N_HIDDEN, N_OUT, LEARNING_RATE, math.fn.tanh, rng);
+
+ for (let epoch = 0; epoch < EPOCHS; epoch++) {
+ if (epoch !== 0 && epoch % 10 === 0) {
+ print(`epoch: ${epoch}`);
+ }
+ let _data = loadData(TRAIN_NUM);
+
+ classifier.sgd(_data.x, _data.y);
+ }
+
+ let testX = loadData(TEST_NUM).x;
+ let output = null;
+ for (let i = 0; i < 100; i++) {
+ output = classifier.predict(testX);
+ testX.push(output[output.length - 1]);
+ }
+
+ print('-----');
+ for (let i = TEST_NUM; i < testX.length - 1; i++) {
+ print(output[i][0]);
+ }
+ print('-----');
+}
+
+function loadData(dataNum) {
+ let x = []; // sin wave + noise [0, t]
+ let y = []; // t + 1
+ const TIME_STEP = 0.1;
+
+ let noise = () => {
+ return 0.1 * math.random.uniform(-1, 1, rng);
+ }
+
+ for (let i = 0; i < dataNum + 1; i++) {
+ let _t = i * TIME_STEP;
+ let _sin = Math.sin(_t * Math.PI);
+ x[i] = [_sin + noise()];
+
+ if (i !== 0) {
+ y[i - 1] = x[i];
+ }
+ }
+ x.pop();
+
+ return {
+ x: x,
+ y: y
+ };
+}
+
+main();
+```
+
+実験用に `seedrandom` を用いていますが、 `Math.random` をそのまま用いても問題ありません。
+
+```
+$ npm install
+$ node main.js
+```
+
+により、sin波の予測が行われているのが確認できるかと思います。