LoginSignup
15

More than 5 years have passed since last update.

PHPでニューラルネットワークを実装してみる

Last updated at Posted at 2017-12-24

この記事は PHP Advent Calendar 2017 の24日目の記事です。

まえがき

それは2か月前のこと、普段サーバサイドの開発をPHPで行っていた私は、「機械学習ちゃんと勉強したいな〜。参考書とか買って読んでみたけどいまいちピンとこないな〜。」と思いつつも、敷居の高さや言語の障壁を感じなかなか手を動かすことができていませんでした。
そんなときに偶然出会ったのが、PHP Conference 2017でのこの発表でした。「機械学習界隈にPHPerを流入させて大激震を起こす」と銘打たれた講演に感動した私は、自宅に帰るとすぐにPCに向かってこう打ち込んだのでした。

$ composer require niisan-tokyo/phpnn:dev-master

(PHPerなので敢えてgit cloneせずに光遅い問題を克服しつつcomposerでインストール)
しかし実際に動かしてみたところ、なぜか以下のようなエラーに遭遇してしまいまして、

$ cd phpnn/example
$ php parabola.php

Notice: Undefined offset: 54 in my-path/phpnn/src/layer/Base.php on line 118
PHP Notice:  Undefined offset: 55 in my-path/phpnn/src/layer/Base.php on line 118
...

ソースコードを確認しているうちに、「いっそ自分で作り直そう!」と思い立ち、シンプルなニューラルネットワーク(NN)ライブラリを作成しました。やって見た感じたことは「ニューラルネットワークを支える基本的なアルゴリズム(確率的勾配降下法や誤差逆伝搬法)自体は、かなり単純なものである」ということでした。むしろその難しさは、統計学に基づくモデルの設計や学習のチューニングにあるのではいかと思いました。この記事が、同じように感じていた方のきっかけになればと思います。

TL;DR

こちらが実際に作ったものです。高機能性や柔軟性よりも、あまりクラスを作りすぎずにプレーンな配列を駆使してシンプルに作ったので、コード自体は理解しやすいかと思います。1つだけサンプルモデルも用意したのでお気軽にお試しください(※ PHP 7.1 以上必須)。

実際に動かした結果が以下になります。今回は元のサンプルモデルにもあった、与えられた点がドーナツ型の領域内にあるかどうかの2値問題を学習してみました。

損失関数

image.png

うまく学習が進んでいることが分かります。続いて実際の判定結果です。

判定結果

image.png

茶色のサンプル点がNNが正解と判定した領域、正解領域が黒の線で囲まれた領域になります。入力データを複素ガウス分布にしたため領域の外側がサンプル数が少なくうまく学習できていませんが、比較的良い結果が得られています。正答率でいうと93~94%ほどになります。このように簡単なモデルであれば、そこまでディープなネットワークでなくとも、また細かいチューニングをしなくとも高い性能が得られます。

解説

NNの入門的な記事は、PhpNNの作成者である@niisan-tokyoさんのこの記事や、私が勉強する上で非常に役立った「ニューラルネットワークと深層学習」など随所にあるのでここでは割愛させていただいて、実際にどうアルゴリズムをプログラムに落とし込むのかやプログラムの大まかなフローだけを解説していきます。ちなみに、私は「ニューラルネットワークと深層学習」を読みながらこれを作ったので、もし併読していただければうまく対応するかと思いますし、また非常に分かりやすく説明されている文献なのでオススメです。

学習過程

シミュレーターの根幹部分はこのようになっています。

public function run(): int
{
    // Set training data set.

    // Set testing data set.

    for ($i = 1; $i <= $this->epoch; $i++) {
        // Train the network to fit for the model.
        $trainingLoss = $this->network->train($this->trainingInputSet, $this->trainingOutputSet);

        // Test the network and validate the output of the network.
        [$testingLoss, $validity] = $this->network->test(
            $this->testingInputSet,
            $this->testingOutputSet,
            $this->getValidator()
        );
    }
}

初めにに訓練(training)用のデータセットと検証(testing)用のデータセットを用意し、指定した世代数まで訓練を繰り返し、その結果を検証していきます。したがって Network クラスには、 traintest メソッドが存在します。

train メソッドにおける学習の過程は以下のように、順伝搬(普通にNNにデータを入力し計算させること)と逆伝搬を繰り返し、ミニバッチサイズ毎に重みとバイアスを更新していくというシンプルなものです。

public function train(array $inputSet, array $answerSet): float
{
    $count = 0;
    $loss = 0.0;

    foreach (array_shuffle(array_keys($inputSet)) as $n) {
        $count++;

        // 順伝搬
        $this->forwardPropagate($inputSet[$n]);

        // 逆伝搬
        $this->backwardPropagate($answerSet[$n]);

        // 損失関数を計算
        $loss += $this->lossFunction->loss($this->outputs[$L], $answerSet[$n])
            / (float)$this->batchSize;

        // 学習データ数がミニバッチサイズになったら重みを更新(フィッティング)
        if ($count >= $this->batchSize) {
            $this->update();
            break;
        }
    }

    return $loss;
}

それぞれの関数の中身はアルゴリズムを理解する必要がありますが、プログラムの流れ自体は非常に簡単なものであることがわかります。

ネットワーク設計

NNの設計は、以下のように設定します。

protected $config = [
    'learningRate' => 0.005,       // 学習率
    'batchSize' => 16,             // ミニバッチサイズ
    'numberOfLayers' => 5,         // レイヤー数
    'inputSize' => 2,              // 入力データサイズ
    'outputSize' => 1,             // 出力データサイズ
];

public function setup(): void
{
    // 損失関数(コスト関数)をセット
    $this->network->setLossFunction(new MeanSquareLoss());

    // 各レイヤーの数とアクティベーション関数をセット(レイヤー数が5なので4層を追加)
    $this->network->addLayer(new RectifierNeuron(), 32);
    $this->network->addLayer(new SigmoidNeuron(), 64);
    $this->network->addLayer(new RectifierNeuron(), 32);
    $this->network->addLayer(new TanhNeuron(), 1);
}

NN上の各パラメータ(重み $w^l_{j,k}$, バイアス $b^l_{j}$, 誤差 $\delta^l_{j}$)は、多次元配列としてアクセスします。

$this->weights[$l][$j][$k];
$this->biases[$l][$j];
$this->errors[$l][$j];

まとめ

やはり、自分でコードを書いて動かしてみるのが一番の理解に繋がると実感しました。何度も述べていますが、ニューラルネットワークの基本的なアルゴリズムは非常に簡単なものであることが分かりました。ところで、なぜそんな単純なアルゴリズムでうまくフィッティングができるのかというと、大量のデータを使って、繰り返し大量のパラメータを細かく調整し、任意の関数を近似しているからです(※少し語弊があるかもしれません)。その驚くべきところはむしろ、大量の計算を効率よく行うためのアルゴリズムが非常によくできていることではないかと思います。PHPerのみなさんもぜひ、PHPで機械学習をしてみませんか?(ところでPHPにはすでに機械学習ライブラリ PHP-ML があるようです。こちらもよかったら試してみてください)

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
15