10
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 5 years have passed since last update.

単純パーセプトロンのソース解読が難解すぎる(僕には)

Last updated at Posted at 2019-05-06

#はじめに

Python機械学習プログラミング第二版の、単純Perceptronのソースコードが難しかったです。
一つ一つ関数を調べながら四時間戦って、意味がわかった気がします。
戻ってきた時のためのメモです!

#クラス内変数定義部分


class Perceptron(object):
    """パーセプトロン分類記.
    パラメータ
    ------------
    eta : float イータと読む
      Learning rate (between 0.0 and 1.0)
    n_iter : int
      Passes over the training dataset.
    random_state : int
      Random number generator seed for random weight
      initialization.
    Attributes
    -----------
    w_ : 1d-array 1次元配列 
    トレーニングデータセットの次元数分+1個入るハズ。
      Weights after fitting. 適合後の重み
    errors_ : list
      Number of misclassifications (updates) in each epoch.

    """
  • eta:学習率、0<eta<1、イータと読む。
  • n_iter:トレーニング回数(エポック数)何回トレーニングさせるか。
  • random_state:??? 後の、重みを初期化する時に0になるのを避ける為っぽいけどよくわかりませんでした。
  • w_:更新後の重みが入る。トレーニングデータの次元数+1個入る。
  • errors_:何回までミスっていいか

#initメソッド


    def __init__(self, eta=0.01, n_iter=50, random_state=1):
        self.eta = eta
        self.n_iter = n_iter
        self.random_state = random_state

このクラスをインスタンス化するときは必ず初期化する。
学習率:0.01
学習の回数:50
ランダム数字:1?とりあえず

#fitメソッド


    def fit(self, X, y):
        """Fit training data.
        Parameters
        ----------
        X : {array-like}, shape = [n_samples, n_features]
          Training vectors, where n_samples is the number of samples and
          n_features is the number of features.
        y : array-like, shape = [n_samples]
          Target values.
        Returns
        -------
        self : object
        """

n_samples:サンプルの個数100個
n_features:特徴量の個数2次元
X:2次元の特徴量が100個入ってる。がく片の長さと、花びらの長さの二つの情報が100個
y:正解のラベルが入っている配列。サンプルの個数分(100個)。


        rgen = np.random.RandomState(self.random_state)
        self.w_ = rgen.normal(loc=0.0, scale=0.01, size=1 + X.shape[1])
        self.errors_ = []

一行目でRandomStateのインスタンスをrgenに代入する。
RandomStateは乱数を発生させるクラスらしい。擬似乱数生成機。
引数は、乱数のシード?らしくて、今回はself.random_state=50に設定。シードよくわからぬ。
そして、一回シードに0を渡して(初期化して)、同じシードを渡すと同じ乱数を出力する。同じ結果を再現できる。
*rgen.randn(10)とすれば、10個の乱数を表示できる

二行目はself.w_に格納されている重みをゼロベクトルに初期化するもの。
インスタンスメソッドnormalは正規分布から無作為にデータを抽出する。
引数loc:分布の中心
引数scale:分布の標準偏差
引数size:出力形状の指定。X.shapeはタプル型で、その配列の2番目はn_features(特徴量の個数)
出力形状は色々あるけど、ここでは、sizeの数分(101個)のデータが出力される。

http://zeema.hatenablog.com/entry/2017/11/05/130031
https://docs.scipy.org/doc/numpy/reference/generated/numpy.random.RandomState.normal.html#numpy.random.RandomState.normal


        for _ in range(self.n_iter):
            errors = 0
            for xi, target in zip(X, y): 
                update = self.eta * (target - self.predict(xi))
                self.w_[1:] += update * xi
                self.w_[0] += update
                errors += int(update != 0.0)
            self.errors_.append(errors)
        return self

for xi, target in zip(X, y):

zip関数:複数のリストからインデックス順に取得。複数の配列を、いい感じに変数に置けるメソッドと理解。
Xとyを前から順番に一つずつとってくる。このときxiとtargetのナンバリングは一致している。

三行目について。Xは特徴量の2次元が100個入った配列。yは正解ラベルが100個入った配列(index)。
つまり、i番目のデータについて、xiは特徴量、targetは正解ラベルが入っている。
▼zip関数について
https://note.nkmk.me/python-zip-usage-for/

self.w_[1:] += update * xi
self.w_[0] += update

self.w_の配列の中に、重みを足し合わせていく。
この時、0番目がバイアスユニット。
1番目が特徴量の1次元目についての重み。w1的なやつ。
2番目が特徴量の2次元目についての重み。w2的なやつ。

▼ここ分からんかったら、このサイト。データ数を3つで考えていて、単純化されていてわかりやすい
https://thinkit.co.jp/article/10342?page=0%2C2

errors += int(update != 0.0)

(update = Δw)
updateが0でないならば int(True)=1を代入。
updateが0であるならば int(False)=0を代入。
errorsは self.w_の状態で、updateが0でない回数を記録する変数。
つまり、updateが全部0なら誤差を修正しきったということ。

self.errors_.append(errors)

self.errors_が0に収束していけば、誤差を修正終了!!
errorsが0だから。

#net_inputメソッド


    def net_input(self, X):
        """Calculate net input"""
        return np.dot(X, self.w_[1:]) + self.w_[0]

net_inputは総入力zを返す関数。
np.dot関数は行列の積や、ベクトルの内積を計算する。
z = w0 + w1 * x1 + w2 * 2... みたいなやつを分解します。

w0 = self.w_[0]
w1 = self.w_[1]
w2 = self.w_[2]
と照らし合わせられる。

Xは12の行列
self.w_[1:]は2
1の行列
なので行列の積が求まる。

よって、
z = np.dot(X,self.w_[1:]) + self.w_[0]
これが総入力!これをreturn

#predictメソッド


    def predict(self, X):
        """Return class label after unit step"""
        return np.where(self.net_input(X) >= 0.0, 1, -1)

returnされた総入力が0.0を超えていれば1をreturn
超えていなければ-1をreturn

#まとめ

十分小さい重み(!=0)で学習スタート

特徴量xiと正解ラベルtargetの差に比例して学習を進める。
(差が大きいほど、updateが大きい)

n_iter回まで学習を進めて、updateが0に収束していたらいい感じ

▼解説した総ソースコード
https://github.com/rasbt/python-machine-learning-book-2nd-edition

10
8
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
10
8

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?