607
435

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

金子勇さんのED法を実装してMNISTを学習させてみた

Last updated at Posted at 2024-04-19

追記:続きを書きました。

その2:ED法を高速化してその性能をMNISTで検証してみた
その3:ED法+交差エントロピーをTF/Torchで実装してみた(おまけでBitNet×ED法を検証)

はじめに

先日以下の記事が投稿され、その斬新な考え方に個人的ながら衝撃を受けました。

内容をざっくり言うと、ニューラルネットワークの学習を現在の主流であるBP法(誤差逆伝播法)ではなく、ED法(誤差拡散法)という新しい学習手法を提案しているものです。
もし記事の内容が本当ならニューラルネットワークの学習がO(1)でできてしまう事になり、まさしく革命が起きてしまいます。
(結論からいうと速度面はそこまででもなかったです(それでも早くなる可能性あり))
(ただこの新手法のポテンシャルは革命を起こす可能性は秘めているといっても過言ではありません)

ED法に関してネットを探すとインターネットアーカイブに情報が少し残っていました。

このページですがED法のサンプルプログラム(C言語)が残っており、このサンプルプログラムをベースにpythonで書き起こしたものが本記事となります。
前半を実装中に分かった範囲での解説、後半にコードを置いておきます。

※私は脳科学の分野もシミュレーションの分野も数学の分野も素人なので、細かい間違いはすいません…。
※本記事では数学的な背景には触れません。(というかできない)あくまで実装メインになります。

ED法(誤差拡散学習法)

概要については上記記事が詳しいのでそちらも是非見てみてください。
アイデアは、BP法の結果を後ろから前に伝播するという動作が実際の脳の動きとしてはおかしい、という所からきているらしいです。
ではどうするかというと「非常に簡単で強力な階層型神経網の学習法則が存在し、それが実際の神経系で使われているはず」と仮定し、これをアミン系の神経伝達物質とします。
アミン系は情報が1対1で伝達されるのではなく、ブロードキャスト的に伝わる特徴があるそうです。
これを再現しようとしたのがED法で、以下の特徴があります。

・出力層の誤差信号(教師信号)をそのまま中間層でも用いる
・興奮性と抑制性の神経細胞を用いて出力を制御する

ちなみにED法はErrorDiffusion法の略らしいです。

※かなり要約しているので厳密性に欠ける点は注意してください

1. ネットワークの構造

BP法ではニューロンは1種類でしたが、前述の通り興奮性ニューロンと抑制性ニューロンの2種類になっています。

fig1.png

"+"が興奮性ニューロンで"-"が抑制性ニューロンです。
"+"から下に伸びてる線が興奮性シナプス、"-"から下に伸びてる線が抑制性シナプスで●がついています。

一番下が出力になり、出力の値を増加したい場合はプラス→プラスのシナプスとマイナス→マイナスのシナプスを強くすれば出力の値が増えます。(矢印の箇所)
※出力を下げる場合は反対側のシナプスを強くする
※なぜか右上だけ逆ですね
※youtubeの動画はこちらの図の説明となります

ただ、これ以降の説明は以下の別の形で行っています。

fig3.png

こちらはプラスで始めるシナプスを強くすれば出力があがる形です。
図3は実際の神経系とは違うようですが、3層構造でも同じルールを適用できるとの事でこちらを使っています。
プログラムコードも図3ベースで実装されているのでこの図3の形を前提に以下説明をしていきます。

2. 学習則

各ニューロンの重みですが、更新式は以下のようです。

8.png

どうしてこの式が出てくるのかは前提となる式の画像が消えており…、分かりません。
各記号に関してですが、$k$は層番号、$i$は接続元のニューロンidx、$j$は接続先のニューロンidx、$w$が重み、$\epsilon$が学習率?、$o^k_j$が接続先のニューロンの出力値、$o^{k-1}_i$が接続元のニューロンの出力値だと思われます。
sign関数はxの符号に応じて{-1,0,1}を返す関数です。
$d$ですが教師データと最終的な出力値の差分で以下となります。

y = 教師データ
o = 最終的な出力値
if y - o > 0:
    d = y - o
else:
    d = o - y

最後に$f'$ですが、出力関数(活性化関数)をシグモイド関数と仮定した場合は以下です。

9.png

また各重みですが、接続元と先が同種の重みは $w^k_{ij} > 0$、異種の場合は $w^k_{ij} < 0$ の制約を持ちます。

実装

C言語の実装では各レイヤーを行列構造で保持し、リカレント型と見なして再帰的に更新しています。
この記事での実装ではTensorflowっぽくニューロンをモジュール単位として実装しました。

ニューロンのイメージは以下です。

aa.drawio.png

"+"と"-"がある以外は既存のニューロンと変わりません。
コードにしかありませんが、"beta"という入力が全ニューロンに追加されていました。(多分biasと同じ効果?)
また、最初の入力値は"+"と"-"に同じ値を分けて使います。(なので必ず入力は2n)
最終的な出力層は"+"ニューロンのみを使います。

1. Neuronクラス

# sigmoidの計算は元コードより、一般的な形と少し違いました。-2/u0 って何でしょうかね?
def sigmoid(x, u0=0.4):
    return 1 / (1 + math.exp(-2 * x / u0))

def sigmoid_derivative(x):
    return sigmoid(x) * (1 - sigmoid(x))

class Neuron:
    def __init__(
        self,
        in_neurons: list["Neuron"],
        ntype: str,          # "p": positive, "n": negative
        alpha: float = 0.8,  # 多分 learning rate
        activation=sigmoid,
        activation_derivative=sigmoid_derivative,
    ) -> None:
        self.ntype = ntype
        self.alpha = alpha
        self.activation = activation
        self.activation_derivative = activation_derivative

        # --- init weights
        # 0~1の乱数で初期化
        # 接続の種類で符号を変える: pp+ pn- np- nn+
        self.weights = []
        for n in in_neurons:
            if ntype == "p":
                if n.ntype == "p":
                    ope = 1
                else:
                    ope = -1
            else:
                if n.ntype == "p":
                    ope = -1
                else:
                    ope = 1
            self.weights.append(random.random() * ope)

        # --- operator
        # 元のowに相当、ここの符号はよくわからなかったので元コードをそのまま再現
        self.operator = 1 if ntype == "p" else -1
        self.weights_operator = [n.operator for n in in_neurons]

        # --- update index
        # 入力元が+ならupper時に学習
        # 入力元が-ならlower時に学習
        self.upper_idx_list = []
        self.lower_idx_list = []
        for i, n in enumerate(in_neurons):
            if n.ntype == "p":
                self.upper_idx_list.append(i)
            else:
                self.lower_idx_list.append(i)

    def forward(self, x):
        # 順方向の計算は既存と同じ
        assert len(self.weights) == len(x)
        y = [x[i] * self.weights[i] for i in range(len(self.weights))]
        y = sum(y)
        self.prev_in = x   # update用に一時保存
        self.prev_out = y  # update用に一時保存
        y = self.activation(y)
        return y

    def update_weight(self, delta_out, direct: str):
        # 誤差拡散による学習、逆伝搬というと怒られそう(笑)

        # f'(o)
        # 元コードではsigmoidを通した後の値を保存して利用することで少し軽量化している
        grad = self.activation_derivative(abs(self.prev_out))

        if direct == "upper":
            indices = self.upper_idx_list
        else:
            indices = self.lower_idx_list

        for idx in indices:
            delta = self.alpha * self.prev_in[idx]
            delta *= grad
            delta *= delta_out * self.operator * self.weights_operator[idx]
            self.weights[idx] += delta

2. Modelクラス

ニューロンを組み合わせた全体のモデルです。
とりあえず3層モデルを試してみました。

class ThreeLayerModel:
    def __init__(
        self,
        input_num: int,
        hidden_num: int,
        alpha: float = 0.8,
        beta: float = 0.8,
    ) -> None:
        self.beta = beta

        # 元コード上は [hd+, hd-] とprintされるもの
        # 多分bias?
        hd_p = Neuron([], "p")
        hd_n = Neuron([], "n")

        # input
        # 入力はpとnそれぞれを作成
        inputs: list[Neuron] = []
        for i in range(input_num):
            inputs.append(Neuron([], "p"))
            inputs.append(Neuron([], "n"))

        # hidden
        # 入力は、[hd+, hd-, in1+, in1-, in2+, in2-, ...]
        self.hidden_neurons: list[Neuron] = []
        for i in range(hidden_num):
            self.hidden_neurons.append(
                Neuron(
                    [hd_p, hd_n] + inputs,
                    ntype=("p" if i % 2 == 1 else "n"),  # 元コードに合わせて-から作成
                    alpha=alpha,
                )
            )

        # output
        # 入力は [hd+, hd-, h1-, h2+, h3-, ...]
        self.out_neuron = Neuron([hd_p, hd_n] + self.hidden_neurons, "p", alpha=alpha)

    def forward(self, inputs):
        # 入力用の配列を作成、入力をp用とn用に複製
        x = []
        for n in inputs:
            x.append(n)  # p
            x.append(n)  # n

        # hidden layerのforward
        # 入力に [hd+, hd-] も追加
        x = [h.forward([self.beta, self.beta] + x) for h in self.hidden_neurons]

        # out layer forward
        # 入力に [hd+, hd-] も追加
        x = self.out_neuron.forward([self.beta, self.beta] + x)

        return x

    def train(self, inputs, target):
        x = self.forward(inputs)

        # --- update(ED)
        # 差分を取得し、更新方向を見る
        diff = target - x
        if diff > 0:
            direct = "upper"
        else:
            direct = "lower"
            diff = -diff
        
        # 各ニューロンを更新
        for h in self.hidden_neurons:
            h.update_weight(diff, direct)
        self.out_neuron.update_weight(diff, direct)

        return diff

3. 実行結果

以下3種類のタスクを見てみました。
XORの結果はサンプルプログラムとほぼ同じ結果になることを確認しています。

# 1. 全ての値で1を返す
#    とりあえず学習できるか見てみるテスト
dataset : 0->1, 1->1

# 2. Notの学習
dataset : 0->1, 1->0

# 3. XORの学習
dataset : [0,0]->1, [1,0]->0, [0,1]->0, [1,1]->1

XORのコード例は以下です。

def main_xor():
    model = ThreeLayerModel(2, hidden_num=16)

    # --- train loop
    dataset = [
        [0, 0, 1.0],
        [1, 0, 0.0],
        [0, 1, 0.0],
        [1, 1, 1.0],
    ]
    for i in range(100):
        x1, x2, target = dataset[random.randint(0, len(dataset)) - 1]
        metric = model.train([x1, x2], target)

        # --- predict
        y = model.forward([x1, x2])
        print(f"{i} in[{x1:5.2f},{x2:5.2f}] -> {y:5.2f}, target {target:5.2f}, metric {metric:5.2f}")

    print("--- result ---")
    for x1, x2, target in dataset:
        y = model.forward([x1, x2])
        print(f"[{x1:5.2f},{x2:5.2f}] -> {y:5.2f}, target {target:5.2f}")

    # --- last weights
    print("--- weights ---")
    print(model.out_neuron)
    for n in model.hidden_neurons:
        print(n)
    1. 全ての値で1を返す
--- result ---
 0.00 ->  0.95, target  1.00
 1.00 ->  0.96, target  1.00
--- weights ---
p  1 [ 1.620( 1,+), -0.824(-1,-), -0.653(-1,-)]
n -1 [-1.045( 1,+),  0.429(-1,-), -0.694( 1,+),  0.206(-1,-)]

0.95とほぼ1になるように学習できていますね。
sigmoidなので完全に1まで学習は難しいはずです。

学習ログ
0 in 1.00 ->  0.53, target  1.00, metric  0.58
1 in 1.00 ->  0.61, target  1.00, metric  0.47
2 in 1.00 ->  0.67, target  1.00, metric  0.39
3 in 0.00 ->  0.59, target  1.00, metric  0.54
4 in 0.00 ->  0.68, target  1.00, metric  0.41
5 in 1.00 ->  0.81, target  1.00, metric  0.20
6 in 1.00 ->  0.82, target  1.00, metric  0.19
7 in 1.00 ->  0.83, target  1.00, metric  0.18
8 in 0.00 ->  0.77, target  1.00, metric  0.27
9 in 1.00 ->  0.86, target  1.00, metric  0.15
10 in 0.00 ->  0.80, target  1.00, metric  0.22
11 in 1.00 ->  0.87, target  1.00, metric  0.13
12 in 1.00 ->  0.88, target  1.00, metric  0.13
13 in 0.00 ->  0.83, target  1.00, metric  0.19
14 in 0.00 ->  0.84, target  1.00, metric  0.17
15 in 0.00 ->  0.85, target  1.00, metric  0.16
16 in 0.00 ->  0.86, target  1.00, metric  0.15
17 in 1.00 ->  0.90, target  1.00, metric  0.10
18 in 0.00 ->  0.87, target  1.00, metric  0.14
19 in 0.00 ->  0.87, target  1.00, metric  0.13
20 in 1.00 ->  0.91, target  1.00, metric  0.09
21 in 1.00 ->  0.91, target  1.00, metric  0.09
22 in 0.00 ->  0.88, target  1.00, metric  0.12
23 in 1.00 ->  0.92, target  1.00, metric  0.08
24 in 0.00 ->  0.89, target  1.00, metric  0.11
25 in 1.00 ->  0.92, target  1.00, metric  0.08
26 in 0.00 ->  0.89, target  1.00, metric  0.11
27 in 1.00 ->  0.92, target  1.00, metric  0.08
28 in 0.00 ->  0.90, target  1.00, metric  0.10
29 in 1.00 ->  0.93, target  1.00, metric  0.07
30 in 0.00 ->  0.90, target  1.00, metric  0.10
31 in 1.00 ->  0.93, target  1.00, metric  0.07
32 in 1.00 ->  0.93, target  1.00, metric  0.07
33 in 1.00 ->  0.93, target  1.00, metric  0.07
34 in 1.00 ->  0.93, target  1.00, metric  0.07
35 in 1.00 ->  0.93, target  1.00, metric  0.07
36 in 1.00 ->  0.93, target  1.00, metric  0.07
37 in 0.00 ->  0.91, target  1.00, metric  0.09
38 in 1.00 ->  0.94, target  1.00, metric  0.07
39 in 0.00 ->  0.91, target  1.00, metric  0.09
40 in 1.00 ->  0.94, target  1.00, metric  0.06
41 in 0.00 ->  0.92, target  1.00, metric  0.08
42 in 1.00 ->  0.94, target  1.00, metric  0.06
43 in 1.00 ->  0.94, target  1.00, metric  0.06
44 in 0.00 ->  0.92, target  1.00, metric  0.08
45 in 0.00 ->  0.92, target  1.00, metric  0.08
46 in 0.00 ->  0.92, target  1.00, metric  0.08
47 in 1.00 ->  0.94, target  1.00, metric  0.06
48 in 1.00 ->  0.94, target  1.00, metric  0.06
49 in 1.00 ->  0.94, target  1.00, metric  0.06
50 in 0.00 ->  0.93, target  1.00, metric  0.08
51 in 1.00 ->  0.94, target  1.00, metric  0.06
52 in 1.00 ->  0.94, target  1.00, metric  0.06
53 in 1.00 ->  0.94, target  1.00, metric  0.06
54 in 0.00 ->  0.93, target  1.00, metric  0.07
55 in 1.00 ->  0.95, target  1.00, metric  0.05
56 in 1.00 ->  0.95, target  1.00, metric  0.05
57 in 1.00 ->  0.95, target  1.00, metric  0.05
58 in 0.00 ->  0.93, target  1.00, metric  0.07
59 in 0.00 ->  0.93, target  1.00, metric  0.07
60 in 1.00 ->  0.95, target  1.00, metric  0.05
61 in 1.00 ->  0.95, target  1.00, metric  0.05
62 in 1.00 ->  0.95, target  1.00, metric  0.05
63 in 1.00 ->  0.95, target  1.00, metric  0.05
64 in 1.00 ->  0.95, target  1.00, metric  0.05
65 in 1.00 ->  0.95, target  1.00, metric  0.05
66 in 0.00 ->  0.94, target  1.00, metric  0.07
67 in 1.00 ->  0.95, target  1.00, metric  0.05
68 in 0.00 ->  0.94, target  1.00, metric  0.06
69 in 1.00 ->  0.95, target  1.00, metric  0.05
70 in 1.00 ->  0.95, target  1.00, metric  0.05
71 in 1.00 ->  0.95, target  1.00, metric  0.05
72 in 1.00 ->  0.95, target  1.00, metric  0.05
73 in 0.00 ->  0.94, target  1.00, metric  0.06
74 in 1.00 ->  0.95, target  1.00, metric  0.05
75 in 1.00 ->  0.95, target  1.00, metric  0.05
76 in 0.00 ->  0.94, target  1.00, metric  0.06
77 in 0.00 ->  0.94, target  1.00, metric  0.06
78 in 0.00 ->  0.94, target  1.00, metric  0.06
79 in 0.00 ->  0.94, target  1.00, metric  0.06
80 in 1.00 ->  0.95, target  1.00, metric  0.05
81 in 1.00 ->  0.96, target  1.00, metric  0.05
82 in 0.00 ->  0.94, target  1.00, metric  0.06
83 in 1.00 ->  0.96, target  1.00, metric  0.04
84 in 1.00 ->  0.96, target  1.00, metric  0.04
85 in 1.00 ->  0.96, target  1.00, metric  0.04
86 in 1.00 ->  0.96, target  1.00, metric  0.04
87 in 1.00 ->  0.96, target  1.00, metric  0.04
88 in 0.00 ->  0.95, target  1.00, metric  0.06
89 in 0.00 ->  0.95, target  1.00, metric  0.05
90 in 0.00 ->  0.95, target  1.00, metric  0.05
91 in 1.00 ->  0.96, target  1.00, metric  0.04
92 in 1.00 ->  0.96, target  1.00, metric  0.04
93 in 1.00 ->  0.96, target  1.00, metric  0.04
94 in 1.00 ->  0.96, target  1.00, metric  0.04
95 in 1.00 ->  0.96, target  1.00, metric  0.04
96 in 0.00 ->  0.95, target  1.00, metric  0.05
97 in 0.00 ->  0.95, target  1.00, metric  0.05
98 in 0.00 ->  0.95, target  1.00, metric  0.05
99 in 0.00 ->  0.95, target  1.00, metric  0.05
    1. Notの学習
--- result ---
 0.00 ->  0.97, target  1.00
 1.00 ->  0.02, target  0.00
--- weights ---
p  1 [ 1.381( 1,+), -0.869(-1,-), -1.356(-1,-),  1.041( 1,+)]
n -1 [-0.837( 1,+),  0.586(-1,-), -0.044( 1,+),  0.949(-1,-)]
p  1 [ 0.932( 1,+), -0.816(-1,-),  0.405( 1,+), -0.851(-1,-)]

0が0.97、1が0.02と学習できていますね。

学習ログ
0 in 0.00 ->  0.71, target  1.00, metric  0.51
1 in 1.00 ->  0.34, target  0.00, metric  0.68
2 in 0.00 ->  0.72, target  1.00, metric  0.54
3 in 1.00 ->  0.22, target  0.00, metric  0.54
4 in 0.00 ->  0.75, target  1.00, metric  0.51
5 in 1.00 ->  0.17, target  0.00, metric  0.39
6 in 1.00 ->  0.11, target  0.00, metric  0.17
7 in 0.00 ->  0.79, target  1.00, metric  0.45
8 in 1.00 ->  0.13, target  0.00, metric  0.22
9 in 0.00 ->  0.84, target  1.00, metric  0.27
10 in 0.00 ->  0.88, target  1.00, metric  0.16
11 in 1.00 ->  0.12, target  0.00, metric  0.21
12 in 0.00 ->  0.88, target  1.00, metric  0.16
13 in 1.00 ->  0.10, target  0.00, metric  0.15
14 in 1.00 ->  0.08, target  0.00, metric  0.10
15 in 0.00 ->  0.89, target  1.00, metric  0.15
16 in 1.00 ->  0.08, target  0.00, metric  0.10
17 in 0.00 ->  0.90, target  1.00, metric  0.13
18 in 0.00 ->  0.92, target  1.00, metric  0.10
19 in 1.00 ->  0.08, target  0.00, metric  0.10
20 in 1.00 ->  0.07, target  0.00, metric  0.08
21 in 1.00 ->  0.06, target  0.00, metric  0.07
22 in 0.00 ->  0.91, target  1.00, metric  0.10
23 in 1.00 ->  0.06, target  0.00, metric  0.06
24 in 0.00 ->  0.92, target  1.00, metric  0.09
25 in 1.00 ->  0.06, target  0.00, metric  0.06
26 in 1.00 ->  0.05, target  0.00, metric  0.06
27 in 0.00 ->  0.93, target  1.00, metric  0.09
28 in 1.00 ->  0.05, target  0.00, metric  0.05
29 in 1.00 ->  0.05, target  0.00, metric  0.05
30 in 1.00 ->  0.04, target  0.00, metric  0.05
31 in 0.00 ->  0.93, target  1.00, metric  0.08
32 in 1.00 ->  0.04, target  0.00, metric  0.05
33 in 1.00 ->  0.04, target  0.00, metric  0.04
34 in 1.00 ->  0.04, target  0.00, metric  0.04
35 in 1.00 ->  0.04, target  0.00, metric  0.04
36 in 0.00 ->  0.93, target  1.00, metric  0.08
37 in 0.00 ->  0.94, target  1.00, metric  0.07
38 in 1.00 ->  0.04, target  0.00, metric  0.04
39 in 0.00 ->  0.94, target  1.00, metric  0.06
40 in 0.00 ->  0.95, target  1.00, metric  0.06
41 in 1.00 ->  0.04, target  0.00, metric  0.04
42 in 0.00 ->  0.95, target  1.00, metric  0.05
43 in 0.00 ->  0.96, target  1.00, metric  0.05
44 in 1.00 ->  0.04, target  0.00, metric  0.04
45 in 0.00 ->  0.96, target  1.00, metric  0.05
46 in 1.00 ->  0.04, target  0.00, metric  0.04
47 in 0.00 ->  0.96, target  1.00, metric  0.04
48 in 0.00 ->  0.96, target  1.00, metric  0.04
49 in 0.00 ->  0.96, target  1.00, metric  0.04
50 in 1.00 ->  0.04, target  0.00, metric  0.04
51 in 1.00 ->  0.04, target  0.00, metric  0.04
52 in 1.00 ->  0.04, target  0.00, metric  0.04
53 in 1.00 ->  0.04, target  0.00, metric  0.04
54 in 1.00 ->  0.03, target  0.00, metric  0.04
55 in 1.00 ->  0.03, target  0.00, metric  0.03
56 in 0.00 ->  0.96, target  1.00, metric  0.04
57 in 1.00 ->  0.03, target  0.00, metric  0.03
58 in 1.00 ->  0.03, target  0.00, metric  0.03
59 in 1.00 ->  0.03, target  0.00, metric  0.03
60 in 1.00 ->  0.03, target  0.00, metric  0.03
61 in 1.00 ->  0.03, target  0.00, metric  0.03
62 in 1.00 ->  0.03, target  0.00, metric  0.03
63 in 1.00 ->  0.03, target  0.00, metric  0.03
64 in 1.00 ->  0.03, target  0.00, metric  0.03
65 in 0.00 ->  0.96, target  1.00, metric  0.05
66 in 1.00 ->  0.03, target  0.00, metric  0.03
67 in 1.00 ->  0.03, target  0.00, metric  0.03
68 in 0.00 ->  0.96, target  1.00, metric  0.05
69 in 0.00 ->  0.96, target  1.00, metric  0.04
70 in 1.00 ->  0.03, target  0.00, metric  0.03
71 in 1.00 ->  0.03, target  0.00, metric  0.03
72 in 0.00 ->  0.96, target  1.00, metric  0.04
73 in 1.00 ->  0.03, target  0.00, metric  0.03
74 in 1.00 ->  0.03, target  0.00, metric  0.03
75 in 1.00 ->  0.02, target  0.00, metric  0.03
76 in 1.00 ->  0.02, target  0.00, metric  0.02
77 in 1.00 ->  0.02, target  0.00, metric  0.02
78 in 0.00 ->  0.96, target  1.00, metric  0.04
79 in 1.00 ->  0.02, target  0.00, metric  0.02
80 in 1.00 ->  0.02, target  0.00, metric  0.02
81 in 0.00 ->  0.96, target  1.00, metric  0.04
82 in 1.00 ->  0.02, target  0.00, metric  0.02
83 in 0.00 ->  0.96, target  1.00, metric  0.04
84 in 1.00 ->  0.02, target  0.00, metric  0.02
85 in 0.00 ->  0.97, target  1.00, metric  0.04
86 in 1.00 ->  0.02, target  0.00, metric  0.02
87 in 0.00 ->  0.97, target  1.00, metric  0.03
88 in 1.00 ->  0.02, target  0.00, metric  0.02
89 in 1.00 ->  0.02, target  0.00, metric  0.02
90 in 1.00 ->  0.02, target  0.00, metric  0.02
91 in 1.00 ->  0.02, target  0.00, metric  0.02
92 in 0.00 ->  0.97, target  1.00, metric  0.03
93 in 1.00 ->  0.02, target  0.00, metric  0.02
94 in 1.00 ->  0.02, target  0.00, metric  0.02
95 in 1.00 ->  0.02, target  0.00, metric  0.02
96 in 0.00 ->  0.97, target  1.00, metric  0.03
97 in 0.00 ->  0.97, target  1.00, metric  0.03
98 in 1.00 ->  0.02, target  0.00, metric  0.02
99 in 1.00 ->  0.02, target  0.00, metric  0.02
    1. XORの学習
--- result ---
[ 0.00, 0.00] ->  0.98, target  1.00
[ 1.00, 0.00] ->  0.04, target  0.00
[ 0.00, 1.00] ->  0.03, target  0.00
[ 1.00, 1.00] ->  0.98, target  1.00
--- weights ---
p  1 [ 0.959( 1,+), -1.139(-1,-), -0.967(-1,-),  0.441( 1,+), -0.939(-1,-),  1.061( 1,+), -1.146(-1,-),  0.889( 1,+), -0.727(-1,-),  1.076( 1,+), -0.368(-1,-),  1.485( 1,+), -1.143(-1,-),  0.811( 1,+), -0.629(-1,-),  1.252( 1,+), -1.554(-1,-),  1.195( 1,+)]
n -1 [-1.311( 1,+),  1.172(-1,-), -0.463( 1,+),  1.101(-1,-), -0.224( 1,+),  0.802(-1,-)]
p  1 [ 0.238( 1,+), -0.908(-1,-),  0.027( 1,+), -0.452(-1,-),  0.427( 1,+), -0.510(-1,-)]
n -1 [-1.133( 1,+),  0.651(-1,-), -0.921( 1,+),  1.336(-1,-), -0.973( 1,+),  0.373(-1,-)]
(省略)

ニューロン数を16まで増やしました。
ちゃんと学習できていますね。

学習ログ
0 in[ 1.00, 1.00] ->  1.00, target  1.00, metric  0.00
1 in[ 1.00, 1.00] ->  1.00, target  1.00, metric  0.00
2 in[ 1.00, 1.00] ->  1.00, target  1.00, metric  0.00
3 in[ 0.00, 0.00] ->  0.96, target  1.00, metric  0.06
4 in[ 1.00, 0.00] ->  0.89, target  0.00, metric  0.99
5 in[ 0.00, 0.00] ->  0.92, target  1.00, metric  0.18
6 in[ 1.00, 1.00] ->  0.99, target  1.00, metric  0.01
7 in[ 1.00, 1.00] ->  0.99, target  1.00, metric  0.01
8 in[ 0.00, 0.00] ->  0.95, target  1.00, metric  0.07
9 in[ 1.00, 1.00] ->  1.00, target  1.00, metric  0.00
10 in[ 1.00, 1.00] ->  1.00, target  1.00, metric  0.00
11 in[ 1.00, 0.00] ->  0.35, target  0.00, metric  0.95
12 in[ 0.00, 0.00] ->  0.94, target  1.00, metric  0.30
13 in[ 1.00, 0.00] ->  0.01, target  0.00, metric  0.70
14 in[ 1.00, 1.00] ->  1.00, target  1.00, metric  0.60
15 in[ 0.00, 1.00] ->  0.66, target  0.00, metric  1.00
16 in[ 1.00, 1.00] ->  0.95, target  1.00, metric  0.08
17 in[ 0.00, 1.00] ->  0.00, target  0.00, metric  0.73
18 in[ 1.00, 1.00] ->  0.92, target  1.00, metric  0.96
19 in[ 1.00, 0.00] ->  0.05, target  0.00, metric  0.17
20 in[ 0.00, 0.00] ->  0.79, target  1.00, metric  0.94
21 in[ 0.00, 1.00] ->  0.01, target  0.00, metric  0.37
22 in[ 0.00, 0.00] ->  0.98, target  1.00, metric  0.77
23 in[ 0.00, 0.00] ->  0.98, target  1.00, metric  0.02
24 in[ 0.00, 1.00] ->  0.02, target  0.00, metric  0.25
25 in[ 0.00, 1.00] ->  0.01, target  0.00, metric  0.02
26 in[ 1.00, 0.00] ->  0.00, target  0.00, metric  0.45
27 in[ 1.00, 1.00] ->  0.96, target  1.00, metric  0.83
28 in[ 0.00, 0.00] ->  0.96, target  1.00, metric  0.06
29 in[ 0.00, 0.00] ->  0.96, target  1.00, metric  0.04
30 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
31 in[ 0.00, 1.00] ->  0.04, target  0.00, metric  0.11
32 in[ 0.00, 0.00] ->  0.96, target  1.00, metric  0.06
33 in[ 1.00, 1.00] ->  0.96, target  1.00, metric  0.05
34 in[ 0.00, 1.00] ->  0.03, target  0.00, metric  0.05
35 in[ 1.00, 1.00] ->  0.96, target  1.00, metric  0.05
36 in[ 0.00, 1.00] ->  0.03, target  0.00, metric  0.04
37 in[ 1.00, 1.00] ->  0.96, target  1.00, metric  0.05
38 in[ 0.00, 0.00] ->  0.96, target  1.00, metric  0.04
39 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.04
40 in[ 0.00, 1.00] ->  0.03, target  0.00, metric  0.04
41 in[ 1.00, 0.00] ->  0.00, target  0.00, metric  0.49
42 in[ 1.00, 0.00] ->  0.00, target  0.00, metric  0.00
43 in[ 1.00, 1.00] ->  0.99, target  1.00, metric  0.61
44 in[ 1.00, 1.00] ->  0.99, target  1.00, metric  0.01
45 in[ 1.00, 0.00] ->  0.02, target  0.00, metric  0.26
46 in[ 1.00, 1.00] ->  0.96, target  1.00, metric  0.05
47 in[ 1.00, 1.00] ->  0.96, target  1.00, metric  0.04
48 in[ 0.00, 1.00] ->  0.02, target  0.00, metric  0.03
49 in[ 1.00, 1.00] ->  0.96, target  1.00, metric  0.04
50 in[ 0.00, 0.00] ->  0.95, target  1.00, metric  0.10
51 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
52 in[ 1.00, 0.00] ->  0.03, target  0.00, metric  0.05
53 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
54 in[ 0.00, 0.00] ->  0.96, target  1.00, metric  0.06
55 in[ 0.00, 1.00] ->  0.03, target  0.00, metric  0.06
56 in[ 0.00, 0.00] ->  0.96, target  1.00, metric  0.05
57 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
58 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
59 in[ 0.00, 1.00] ->  0.03, target  0.00, metric  0.05
60 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
61 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
62 in[ 0.00, 1.00] ->  0.02, target  0.00, metric  0.03
63 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
64 in[ 0.00, 0.00] ->  0.97, target  1.00, metric  0.04
65 in[ 1.00, 0.00] ->  0.04, target  0.00, metric  0.06
66 in[ 1.00, 0.00] ->  0.03, target  0.00, metric  0.04
67 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
68 in[ 0.00, 0.00] ->  0.97, target  1.00, metric  0.04
69 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
70 in[ 1.00, 1.00] ->  0.98, target  1.00, metric  0.03
71 in[ 1.00, 1.00] ->  0.98, target  1.00, metric  0.02
72 in[ 1.00, 1.00] ->  0.98, target  1.00, metric  0.02
73 in[ 1.00, 1.00] ->  0.98, target  1.00, metric  0.02
74 in[ 0.00, 1.00] ->  0.03, target  0.00, metric  0.04
75 in[ 1.00, 1.00] ->  0.98, target  1.00, metric  0.02
76 in[ 1.00, 1.00] ->  0.98, target  1.00, metric  0.02
77 in[ 0.00, 1.00] ->  0.02, target  0.00, metric  0.03
78 in[ 1.00, 0.00] ->  0.03, target  0.00, metric  0.05
79 in[ 0.00, 0.00] ->  0.97, target  1.00, metric  0.04
80 in[ 1.00, 1.00] ->  0.98, target  1.00, metric  0.03
81 in[ 1.00, 0.00] ->  0.03, target  0.00, metric  0.04
82 in[ 1.00, 0.00] ->  0.02, target  0.00, metric  0.03
83 in[ 0.00, 1.00] ->  0.02, target  0.00, metric  0.02
84 in[ 0.00, 0.00] ->  0.97, target  1.00, metric  0.04
85 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
86 in[ 0.00, 0.00] ->  0.97, target  1.00, metric  0.03
87 in[ 0.00, 0.00] ->  0.98, target  1.00, metric  0.03
88 in[ 1.00, 0.00] ->  0.03, target  0.00, metric  0.03
89 in[ 0.00, 0.00] ->  0.98, target  1.00, metric  0.02
90 in[ 0.00, 0.00] ->  0.98, target  1.00, metric  0.02
91 in[ 0.00, 0.00] ->  0.98, target  1.00, metric  0.02
92 in[ 0.00, 1.00] ->  0.02, target  0.00, metric  0.03
93 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
94 in[ 1.00, 1.00] ->  0.97, target  1.00, metric  0.03
95 in[ 1.00, 1.00] ->  0.98, target  1.00, metric  0.03
96 in[ 1.00, 1.00] ->  0.98, target  1.00, metric  0.02
97 in[ 0.00, 1.00] ->  0.02, target  0.00, metric  0.03
98 in[ 1.00, 1.00] ->  0.98, target  1.00, metric  0.02
99 in[ 1.00, 1.00] ->  0.98, target  1.00, metric  0.02
--- weights ---
p  1 [ 0.959( 1,+), -1.139(-1,-), -0.967(-1,-),  0.441( 1,+), -0.939(-1,-),  1.061( 1,+), -1.146(-1,-),  0.889( 1,+), -0.727(-1,-),  1.076( 1,+), -0.368(-1,-),  1.485( 1,+), -1.143(-1,-),  0.811( 1,+), -0.629(-1,-),  1.252( 1,+), -1.554(-1,-),  1.195( 1,+)]
n -1 [-1.311( 1,+),  1.172(-1,-), -0.463( 1,+),  1.101(-1,-), -0.224( 1,+),  0.802(-1,-)]
p  1 [ 0.238( 1,+), -0.908(-1,-),  0.027( 1,+), -0.452(-1,-),  0.427( 1,+), -0.510(-1,-)]
n -1 [-1.133( 1,+),  0.651(-1,-), -0.921( 1,+),  1.336(-1,-), -0.973( 1,+),  0.373(-1,-)]
p  1 [ 0.849( 1,+), -0.910(-1,-),  0.262( 1,+), -0.878(-1,-),  0.535( 1,+), -0.773(-1,-)]
n -1 [-0.750( 1,+),  1.223(-1,-), -0.469( 1,+),  0.775(-1,-), -0.414( 1,+),  0.535(-1,-)]
p  1 [ 1.588( 1,+), -1.405(-1,-),  0.852( 1,+), -0.269(-1,-),  0.805( 1,+), -0.871(-1,-)]
n -1 [-1.138( 1,+),  1.228(-1,-), -0.193( 1,+),  0.746(-1,-), -0.440( 1,+),  0.817(-1,-)]
p  1 [ 1.338( 1,+), -1.159(-1,-),  0.668( 1,+), -1.056(-1,-),  1.013( 1,+), -0.141(-1,-)]
n -1 [-1.256( 1,+),  1.117(-1,-), -1.113( 1,+),  0.810(-1,-), -0.815( 1,+),  0.916(-1,-)]
p  1 [ 0.976( 1,+), -0.297(-1,-),  0.638( 1,+), -0.088(-1,-),  0.490( 1,+), -0.860(-1,-)]
n -1 [-1.565( 1,+),  1.764(-1,-), -1.050( 1,+),  1.221(-1,-), -0.638( 1,+),  0.774(-1,-)]
p  1 [ 1.072( 1,+), -0.398(-1,-),  1.013( 1,+), -0.597(-1,-),  0.851( 1,+), -1.061(-1,-)]
n -1 [-1.501( 1,+),  1.038(-1,-), -0.975( 1,+),  0.096(-1,-), -0.763( 1,+),  1.352(-1,-)]
p  1 [ 0.906( 1,+), -1.575(-1,-),  1.281( 1,+), -0.749(-1,-),  1.297( 1,+), -0.885(-1,-)]
n -1 [-0.349( 1,+),  0.825(-1,-), -0.548( 1,+),  0.891(-1,-), -0.100( 1,+),  0.481(-1,-)]
p  1 [ 0.962( 1,+), -0.137(-1,-),  0.937( 1,+), -0.456(-1,-),  0.484( 1,+), -0.537(-1,-)]

コード全体

コード全体
import math
import random

random.seed(10)


def sign(x):
    if x == 0:
        return 0
    return 1 if x > 0 else -1


def relu(x):
    return x if x > 0 else 0


def relu_derivative(x):
    return 1 if x > 0 else 0


def sigmoid(x, u0=0.4):
    return 1 / (1 + math.exp(-2 * x / u0))


def sigmoid_derivative(x):
    return sigmoid(x) * (1 - sigmoid(x))


def linear(x):
    return x


def linear_derivative(x):
    return 1


class Neuron:
    def __init__(
        self,
        in_neurons: list["Neuron"],
        ntype: str,  # "p": positive, "n": negative
        alpha: float = 0.8,  # 多分 learning rate
        activation=sigmoid,
        activation_derivative=sigmoid_derivative,
    ) -> None:
        self.ntype = ntype
        self.alpha = alpha
        self.activation = activation
        self.activation_derivative = activation_derivative

        # --- init weights
        # weight: pp+ pn- np- nn+
        self.weights = []
        for n in in_neurons:
            if ntype == "p":
                if n.ntype == "p":
                    ope = 1
                else:
                    ope = -1
            else:
                if n.ntype == "p":
                    ope = -1
                else:
                    ope = 1
            self.weights.append(random.random() * ope)

        # --- operator
        self.operator = 1 if ntype == "p" else -1
        self.weights_operator = [n.operator for n in in_neurons]

        # --- update index
        # 入力元が+ならupper時に学習
        # 入力元が-ならlower時に学習
        self.upper_idx_list = []
        self.lower_idx_list = []
        for i, n in enumerate(in_neurons):
            if n.ntype == "p":
                self.upper_idx_list.append(i)
            else:
                self.lower_idx_list.append(i)

    def forward(self, x):
        assert len(self.weights) == len(x)
        y = [x[i] * self.weights[i] for i in range(len(self.weights))]
        y = sum(y)
        self.prev_in = x
        self.prev_out = y
        y = self.activation(y)
        return y

    def update_weight(self, delta_out, direct: str):
        grad = self.activation_derivative(abs(self.prev_out))

        if direct == "upper":
            indices = self.upper_idx_list
        else:
            indices = self.lower_idx_list

        for idx in indices:
            _old_w = self.weights[idx]
            delta = self.alpha * self.prev_in[idx]
            delta *= grad
            delta *= delta_out * self.operator * self.weights_operator[idx]
            self.weights[idx] += delta

            # --- debug
            s = f"{idx:2d}"
            s += f", ot_in {self.prev_in[idx]:5.2f}"
            s += f", f'({self.prev_out:5.2f})={grad:5.2f}"
            s += f", del_ot {delta_out:5.2f}"
            s += f", d {delta:6.3f}"
            s += f", {self.operator:2d} {self.weights_operator[idx]:2d}"
            s += f", w {_old_w:5.2f} -> {self.weights[idx]:5.2f}"
            # print(s)

    def __str__(self):
        s = f"{self.ntype} {self.operator:2d}"
        arr = []
        for i in range(len(self.weights)):
            o = "+" if i in self.upper_idx_list else "-"
            arr.append(f"{self.weights[i]:6.3f}({self.weights_operator[i]:2d},{o})")
        s += " [" + ", ".join(arr) + "]"
        return s


class ThreeLayerModel:
    def __init__(
        self,
        input_num: int,
        hidden_num: int,
        alpha: float = 0.8,
        beta: float = 0.8,
    ) -> None:
        self.beta = beta

        # [hd+, hd-] (bias?)
        hd_p = Neuron([], "p")
        hd_n = Neuron([], "n")

        # input
        inputs: list[Neuron] = []
        for i in range(input_num):
            inputs.append(Neuron([], "p"))
            inputs.append(Neuron([], "n"))

        # hidden
        self.hidden_neurons: list[Neuron] = []
        for i in range(hidden_num):
            self.hidden_neurons.append(
                Neuron(
                    [hd_p, hd_n] + inputs,
                    ntype=("p" if i % 2 == 1 else "n"),
                    alpha=alpha,
                    activation=sigmoid,
                    activation_derivative=sigmoid_derivative,
                )
            )

        # output
        self.out_neuron = Neuron(
            [hd_p, hd_n] + self.hidden_neurons,
            "p",
            alpha=alpha,
            activation=sigmoid,
            activation_derivative=sigmoid_derivative,
        )

    def forward(self, inputs):
        # in layer
        x = []
        for n in inputs:
            x.append(n)  # p
            x.append(n)  # n

        # hidden layer
        x = [h.forward([self.beta, self.beta] + x) for h in self.hidden_neurons]

        # out layer
        x = self.out_neuron.forward([self.beta, self.beta] + x)

        return x

    def train(self, inputs, target):
        x = self.forward(inputs)

        # --- update(ED)
        diff = target - x
        if diff > 0:
            direct = "upper"
        else:
            direct = "lower"
            diff = -diff
        self.out_neuron.update_weight(diff, direct)
        for h in self.hidden_neurons:
            h.update_weight(diff, direct)

        return diff


def main_one():
    model = ThreeLayerModel(1, hidden_num=1)

    # --- train loop
    dataset = [
        [0, 1],
        [1, 1],
    ]
    for i in range(100):
        x, target = dataset[random.randint(0, len(dataset)) - 1]
        metric = model.train([x], target)

        # --- predict
        y = model.forward([x])
        print(f"{i} in{x:5.2f} -> {y:5.2f}, target {target:5.2f}, metric {metric:5.2f}")

    print("--- result ---")
    for x, target in dataset:
        y = model.forward([x])
        print(f"{x:5.2f} -> {y:5.2f}, target {target:5.2f}")

    # --- last weights
    print("--- weights ---")
    print(model.out_neuron)
    for n in model.hidden_neurons:
        print(n)


def main_not():
    model = ThreeLayerModel(1, hidden_num=2)

    # --- train loop
    dataset = [
        [0, 1],
        [1, 0],
    ]
    for i in range(100):
        x, target = dataset[random.randint(0, len(dataset)) - 1]
        metric = model.train([x], target)

        # --- predict
        y = model.forward([x])
        print(f"{i} in{x:5.2f} -> {y:5.2f}, target {target:5.2f}, metric {metric:5.2f}")

    print("--- result ---")
    for x, target in dataset:
        y = model.forward([x])
        print(f"{x:5.2f} -> {y:5.2f}, target {target:5.2f}")

    # --- last weights
    print("--- weights ---")
    print(model.out_neuron)
    for n in model.hidden_neurons:
        print(n)


def main_xor():
    model = ThreeLayerModel(2, hidden_num=16)

    # --- train loop
    dataset = [
        [0, 0, 1.0],
        [1, 0, 0.0],
        [0, 1, 0.0],
        [1, 1, 1.0],
    ]
    for i in range(100):
        x1, x2, target = dataset[random.randint(0, len(dataset)) - 1]
        metric = model.train([x1, x2], target)

        # --- predict
        y = model.forward([x1, x2])
        print(f"{i} in[{x1:5.2f},{x2:5.2f}] -> {y:5.2f}, target {target:5.2f}, metric {metric:5.2f}")

    print("--- result ---")
    for x1, x2, target in dataset:
        y = model.forward([x1, x2])
        print(f"[{x1:5.2f},{x2:5.2f}] -> {y:5.2f}, target {target:5.2f}")

    # --- last weights
    print("--- weights ---")
    print(model.out_neuron)
    for n in model.hidden_neurons:
        print(n)


if __name__ == "__main__":
    main_one()
    main_not()
    main_xor()

MNISTの学習

MNISTの学習もできちゃいました…。
自分の実装コードに自信がなかったので流石にできないかな?と思ってやる予定はなかったのですが、やってみたら出来てしまった…。

MNISTですが、0と1の学習データだけにしてデータ数も1000まで減らしています。
また、同じモデルのTensorflowも比較用に学習させています。

モデルですがED法の多層モデルも見てみたかったので10層にしてみました。
(MultiLayerModelというクラスを新しく作っています)

Tensorflowだと以下のモデルになります。

model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Input(shape=(28 * 28)),
        tf.keras.layers.Dense(16, activation="sigmoid"),
        tf.keras.layers.Dense(16, activation="sigmoid"),
        tf.keras.layers.Dense(16, activation="sigmoid"),
        tf.keras.layers.Dense(16, activation="sigmoid"),
        tf.keras.layers.Dense(16, activation="sigmoid"),
        tf.keras.layers.Dense(16, activation="sigmoid"),
        tf.keras.layers.Dense(16, activation="sigmoid"),
        tf.keras.layers.Dense(16, activation="sigmoid"),
        tf.keras.layers.Dense(16, activation="sigmoid"),
        tf.keras.layers.Dense(16, activation="sigmoid"),
        tf.keras.layers.Dense(1, activation="sigmoid"),
    ]
)

結果は以下です。

# Tensorflow
1000/1000 [==============================] - 2s 2ms/step - loss: 0.7006 - accuracy: 0.5280
67/67 [==============================] - 0s 1ms/step - loss: 0.6995 - accuracy: 0.5366

# ED法
98.35%

ED法で約98%の予測に成功しています。
これは学習できているといってもいい数値ですね!これBP法じゃないんだぜ…。

また、10層だとTensorflowは学習できていません。(約54%程の精度)
これは勾配消失が起きて学習できていないからだと思われます。(5層に減らすとTensorflowでも学習できる)

学習中のmetrics(正解ラベルと出力値の差)は以下でした。

Figure_1.png

後半はほぼ0になってちゃんと学習できているのが分かります。
冗長だけどコードも貼っておきます。

コード全体
import math
import random

import numpy as np
import pandas as pd
import tensorflow as tf
from matplotlib import pyplot as plt
from tqdm import tqdm

random.seed(10)


def sigmoid(x, u0=0.4):
    return 1 / (1 + math.exp(-2 * x / u0))


def sigmoid_derivative(x):
    return sigmoid(x) * (1 - sigmoid(x))


def linear(x):
    return x


def linear_derivative(x):
    return 1


class Neuron:
    def __init__(
        self,
        in_neurons: list["Neuron"],
        ntype: str,  # "p": positive, "n": negative
        alpha: float = 0.8,  # 多分 learning rate
        activation=sigmoid,
        activation_derivative=sigmoid_derivative,
    ) -> None:
        self.ntype = ntype
        self.alpha = alpha
        self.activation = activation
        self.activation_derivative = activation_derivative

        # --- init weights
        # weight: pp+ pn- np- nn+
        self.weights = []
        for n in in_neurons:
            if ntype == "p":
                if n.ntype == "p":
                    ope = 1
                else:
                    ope = -1
            else:
                if n.ntype == "p":
                    ope = -1
                else:
                    ope = 1
            self.weights.append(random.random() * ope)

        # --- operator
        self.operator = 1 if ntype == "p" else -1
        self.weights_operator = [n.operator for n in in_neurons]

        # --- update index
        # 入力元が+ならupper時に学習
        # 入力元が-ならlower時に学習
        self.upper_idx_list = []
        self.lower_idx_list = []
        for i, n in enumerate(in_neurons):
            if n.ntype == "p":
                self.upper_idx_list.append(i)
            else:
                self.lower_idx_list.append(i)

    def forward(self, x):
        assert len(self.weights) == len(x)
        y = [x[i] * self.weights[i] for i in range(len(self.weights))]
        y = sum(y)
        self.prev_in = x
        self.prev_out = y
        y = self.activation(y)
        return y

    def update_weight(self, delta_out, direct: str):
        grad = self.activation_derivative(abs(self.prev_out))

        if direct == "upper":
            indices = self.upper_idx_list
        else:
            indices = self.lower_idx_list

        for idx in indices:
            _old_w = self.weights[idx]
            delta = self.alpha * self.prev_in[idx]
            delta *= grad
            delta *= delta_out * self.operator * self.weights_operator[idx]
            self.weights[idx] += delta

            # --- debug
            s = f"{idx:2d}"
            s += f", ot_in {self.prev_in[idx]:5.2f}"
            s += f", f'({self.prev_out:5.2f})={grad:5.2f}"
            s += f", del_ot {delta_out:5.2f}"
            s += f", d {delta:6.3f}"
            s += f", {self.operator:2d} {self.weights_operator[idx]:2d}"
            s += f", w {_old_w:5.2f} -> {self.weights[idx]:5.2f}"
            # print(s)

    def __str__(self):
        s = f"{self.ntype} {self.operator:2d}"
        arr = []
        for i in range(len(self.weights)):
            o = "+" if i in self.upper_idx_list else "-"
            arr.append(f"{self.weights[i]:6.3f}({self.weights_operator[i]:2d},{o})")
        s += " [" + ", ".join(arr) + "]"
        return s


class MultiLayerModel:
    def __init__(
        self,
        input_num: int,
        hidden_sizes,
        alpha: float = 0.8,
        beta: float = 0.8,
    ) -> None:
        self.beta = beta

        # [hd+, hd-] (bias?)
        hd_p = Neuron([], "p")
        hd_n = Neuron([], "n")

        # input
        inputs: list[Neuron] = []
        for i in range(input_num):
            inputs.append(Neuron([], "p"))
            inputs.append(Neuron([], "n"))

        # hidden
        self.hidden_neurons_list: list[list[Neuron]] = []
        idx = 0
        prev_neurons = inputs
        for size in hidden_sizes:
            hidden_neurons = []
            for i in range(size):
                hidden_neurons.append(
                    Neuron(
                        [hd_p, hd_n] + prev_neurons,
                        ntype=("p" if idx % 2 == 0 else "n"),
                        alpha=alpha,
                        activation=sigmoid,
                        activation_derivative=sigmoid_derivative,
                    )
                )
                idx += 1
            prev_neurons = hidden_neurons
            self.hidden_neurons_list.append(hidden_neurons)

        # output
        self.out_neuron = Neuron(
            [hd_p, hd_n] + self.hidden_neurons_list[-1],
            "p",
            alpha=alpha,
            activation=sigmoid,
            activation_derivative=sigmoid_derivative,
        )

    def forward(self, inputs):
        # in layer
        x = []
        for n in inputs:
            x.append(n)  # p
            x.append(n)  # n

        # hidden layer
        for neurons in self.hidden_neurons_list:
            x = [h.forward([self.beta, self.beta] + x) for h in neurons]

        # out layer
        x = self.out_neuron.forward([self.beta, self.beta] + x)

        return x

    def train(self, inputs, target):
        x = self.forward(inputs)

        # --- update(ED)
        diff = target - x
        if diff > 0:
            direct = "upper"
        else:
            direct = "lower"
            diff = -diff

        # update
        for neurons in self.hidden_neurons_list:
            for n in neurons:
                n.update_weight(diff, direct)
        self.out_neuron.update_weight(diff, direct)

        return diff


def _create_dataset():
    mnist = tf.keras.datasets.mnist
    (x_train, y_train), (x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0
    x_train = x_train.reshape(x_train.shape[0], -1)
    x_test = x_test.reshape(x_test.shape[0], -1)
    train_indices = np.where((y_train == 0) | (y_train == 1))[0]
    test_indices = np.where((y_test == 0) | (y_test == 1))[0]
    x_train = x_train[train_indices]
    y_train = y_train[train_indices]
    x_test = x_test[test_indices]
    y_test = y_test[test_indices]
    # データ数が多いので削減
    x_train = x_train[:1000]
    y_train = y_train[:1000]
    return (x_train, y_train), (x_test, y_test)


def main_tf():
    (x_train, y_train), (x_test, y_test) = _create_dataset()

    model = tf.keras.models.Sequential(
        [
            tf.keras.layers.Input(shape=(28 * 28)),
            tf.keras.layers.Dense(16, activation="sigmoid"),
            tf.keras.layers.Dense(16, activation="sigmoid"),
            tf.keras.layers.Dense(16, activation="sigmoid"),
            tf.keras.layers.Dense(16, activation="sigmoid"),
            tf.keras.layers.Dense(16, activation="sigmoid"),
            tf.keras.layers.Dense(16, activation="sigmoid"),
            tf.keras.layers.Dense(16, activation="sigmoid"),
            tf.keras.layers.Dense(16, activation="sigmoid"),
            tf.keras.layers.Dense(16, activation="sigmoid"),
            tf.keras.layers.Dense(16, activation="sigmoid"),
            tf.keras.layers.Dense(1, activation="sigmoid"),
        ]
    )
    model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
    model.fit(x_train, y_train, epochs=1, batch_size=1)
    model.evaluate(x_test, y_test)


def main_ed():
    (x_train, y_train), (x_test, y_test) = _create_dataset()

    model = MultiLayerModel(28 * 28, (16, 16, 16, 16, 16, 16, 16, 16, 16, 16), alpha=0.1)

    # --- train loop
    metrics = []
    for i in range(1):
        for j in tqdm(range(len(x_train))):
            x = x_train[j]
            target = y_train[j]
            metric = model.train(x, target)
            metrics.append(metric)

    correct = 0
    total = 0
    for i in tqdm(range(len(x_test))):
        y = model.forward(x_test[i])
        y = 1 if y > 0.5 else 0
        if y_test[i] == y:
            correct += 1
        total += 1
    print(f"{100 * correct / total:.2f}%")

    plt.plot(pd.DataFrame(metrics).rolling(20).mean())
    plt.plot(metrics, alpha=0.2)
    plt.grid()
    plt.xlabel("step")
    plt.ylabel("diff")
    plt.show()


if __name__ == "__main__":
    main_tf()
    main_ed()

最後に

これが1999年に考えられていたなんて衝撃です。
ディープラーニングの誕生が2006年らしいので当時は層を重ねるという発想もまだなかった時代ですね。
ED法もそこに関しては同じで3層+sigmoidがメインですけど今の環境でも通用しそうな凄さがあります。
MNISTでは10層で学習ができたので多層モデルにも期待が持てますね。

  • サンプルプログラムに関して
    久しぶりにC言語を触りました。今の時代実行環境を作るのも簡単で助かりますね。
    また、サンプルプログラムがかなり分かりやすい書き方で助かりました。
    (pythonで実装しなおす必要あるのか?と少し思ったり…)

  • ED法に関して
    BP法じゃない方法でちゃんと学習できている事に驚きました。
    実装時に感じた所感は以下です。

    • weighsの更新は結局全てに対して行われる
       →ただ、BP法みたいに前(後?)を待つ必要がないので、全weightsを並列で一括更新できる
    • 出力は1つだけっぽい、元コードはそれぞれの出力に対して別の重みを使用
       →重みのサイズは増えるけどその分早くなるなら許容できそう
       →そもそも複数出力に対しても案外学習できそう?
  • ED法に関して2
    個人的に気になるのは現環境におけるBP法の代替手段になりえるのか?です。
    ED法はBP法とは違った学習になるので、もし研究が進んでBP法の代替になりえたなら革命が起きるかもしれませんね。
    とりあえずはBP法と同じ道をたどれるかが課題でしょうか。
    ・sigmoid以外の活性化関数に対応できるか
    ・ミニバッチ学習、ドロップアウト、損失関数等々(ドロップアウトっぽい事はyoutubeのスライドに書いてあったので効果ありそう)
    ・RNNやAttention/Transformerへ応用できるか
    等々

ちなみに私は映画Winnyをまだ見ていません。
今度見てみようかな

607
435
25

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
607
435

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?