18
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

Extreme learning machineを実装してみた

Posted at

Extreme Learning Machine (ELM) とは?

ELMは特殊な形式のフィードフォワードパーセプトロンです。隠れ層を1層持ちますが、隠れ層の重みはランダムに決定し、出力層の重みを擬似逆行列を使って決定します。イメージ的には、隠れ層はランダムな特徴抽出器を多数作って、出力層で特徴選択をしてやるといった感じです。

ELMは次のような特性を持っています。

  • 解いているのは単なる擬似逆行列なので、高速に解を求めることができる
  • 凸な問題(局所解がない)
  • 任意の関数を近似可能(普通のパーセプトロンと同じ)
  • 任意のactivationが使用可能(微分できる必要なし)

実装

import numpy as np


class ExtremeLearningMachine(object):
    def __init__(self, n_unit, activation=None):
        self._activation = self._sig if activation is None else activation
        self._n_unit = n_unit

    @staticmethod
    def _sig(x):
        return 1. / (1 + np.exp(-x))

    @staticmethod
    def _add_bias(x):
        return np.hstack((x, np.ones((x.shape[0], 1))))

    def fit(self, X, y):
        self.W0 = np.random.random((X.shape[1], self._n_unit))
        z = self._add_bias(self._activation(X.dot(self.W0)))
        self.W1 = np.linalg.lstsq(z, y)[0]

    def transform(self, X):
        if not hasattr(self, 'W0'):
            raise UnboundLocalError('must fit before transform')
        z = self._add_bias(self._activation(X.dot(self.W0)))
        return z.dot(self.W1)

    def fit_transform(self, X, y):
        self.W0 = np.random.random((X.shape[1], self._n_unit))
        z = self._add_bias(self._activation(X.dot(self.W0)))
        self.W1 = np.linalg.lstsq(z, y)[0]
        return z.dot(self.W1)

テスト

とりあえずirisで試してみる。

方法

from sklearn import datasets

iris = datasets.load_iris()
ind = np.random.permutation(len(iris.data))

y = np.zeros((len(iris.target), 3))
y[np.arange(len(y)), iris.target] = 1

acc_train = []
acc_test = []
N = [5, 10, 15, 20, 30, 40, 80, 160]
for n in N:
    elm = ExtremeLearningMachine(n)
    elm.fit(iris.data[ind[:100]], y[ind[:100]])
    acc_train.append(np.average(np.argmax(elm.transform(iris.data[ind[:100]]), axis=1) == iris.target[ind[:100]]))
    acc_test.append(np.average(np.argmax(elm.transform(iris.data[ind[100:]]), axis=1) == iris.target[ind[100:]]))
plt.plot(N, acc_train, c='red', label='train')
plt.plot(N, acc_test, c='blue', label='test')
plt.legend(loc=1)
plt.savefig("result.png")

結果

result.png

結論

  • いじるパラメータが少ないので普通のニューラルネットワークよりもチューニング楽そう
  • 汎化性能が高いとか元論文にかいてあったが、普通にオーバーフィッティングする。
  • fitもtransformもめちゃくちゃ高速なので、いろいろと用途がありそう。
18
16
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
18
16

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?