1
1

More than 1 year has passed since last update.

ナイーブベイズの実装確認

Posted at

今回調べたこと

ナイーブベイズを実装しようとした際に、最終的な予測の出力で迷った部分があったので、sklearnの実装を参考。
その際に実装の中身を追ったので、内容を記録。

ソースコード

参考コード
今回はBernoulliNBの実装を参考に中身を確認。

公式のサンプルコード
import numpy as np
from sklearn.naive_bayes import BernoulliNB

rng = np.random.RandomState(1)
X = rng.randint(5, size=(6, 100))
Y = np.array([1, 2, 3, 4, 4, 5])

clf = BernoulliNB(force_alpha=True)
clf.fit(X, Y)
print(clf.predict(X[2:3]))

今回は特定のデータに対する予測ラベルではなく、予測確率の計算方法を確認します
sklearnではナイーブベイズの基本クラス_BaseNBの以下関数で計算されています。

def predict_proba(self, X):
    # 入力に対するクラスごとの確率を返す
    return np.exp(self.predict_log_proba(X))

参考コードのデータで確認すると、確かに5つのラベルに対する確率が得られました。
この関数を追っていけば、確率の計算やその正規化の処理がありそうです。
ではまず、returnにある指数関数の入力部分について、

predict_log_proba()
def predict_log_proba(self, X):
    #一部省略
    jll = self._joint_log_likelihood(X)
    # normalize by P(x) = P(f_1, ..., f_n)
    log_prob_x = logsumexp(jll, axis=1)
    return jll - np.atleast_2d(log_prob_x).T

ここで正規化が行われていました。
self._joint_log_likelihood(X)で各ラベルの対数尤度を計算し、
scipy.special.logsumexpで正規化の処理をしています。
具体的には以下の計算を実行しているようです。(参考ページ)

\begin{align}
\log(\Sigma^{N}_{i=1}e^{a_{i}}) &= \log(e^{max_{i} a_{i}}\Sigma^{N}_{i=1}e^{a_{i}-max_{i} a_{i}})\\
&= max_{i} a_{i} + \log(\Sigma^{N}_{i=1}e^{a_{i}-max_{i} a_{i}})
\end{align}

この形で計算することで、尤度計算で起きがちなoverflow問題を解消している模様
この段階で当初の迷った部分は大方解消されましたが、実装の間違いが無いかも含め、更に進んでいきます。

今回の調査対象であるBernoulliNBの周辺尤度は以下の関数で計算されています。
(基本クラス_BaseNBで抽象基底クラスとして定義されているため、クラスごとで中身が異なります)

_joint_log_likelihood()
def _joint_log_likelihood(self, X):
    # 一部省略
    neg_prob = np.log(1 - np.exp(self.feature_log_prob_))
    # Compute  neg_prob · (1 - X).T  as  ∑neg_prob - X · neg_prob
    jll = safe_sparse_dot(X, (self.feature_log_prob_ - neg_prob).T)
    jll += self.class_log_prior_ + neg_prob.sum(axis=1)
    return jll

self.feature_log_prob_は訓練時(_BaseDiscreteNB.fit())で計算されたものを使用。
ということで_BaseDiscreteNB.fit()の中身を確認します。

_BaseDiscreteNB.fit()
def fit(self, X, y, sample_weight=None):
    # 一部省略
    class_prior = self.class_prior

    # Count raw events from data before updating the class log prior
    # and feature log probas
    n_classes = Y.shape[1]
    self._init_counters(n_classes, n_features)
    self._count(X, Y)
    alpha = self._check_alpha()
    self._update_feature_log_prob(alpha)
    self._update_class_log_prior(class_prior=class_prior)
    return self

まず、self._init_counters()ではラベル側とラベルごとの特徴量のカウントを取る行列を作成。
途中でY(n_samples, n_classes)のone-hot行列に変換されています。(Xも同様に変換)
そうすることで、ラベルごとの特徴量のカウントがY.T @ Xで完結します。(ラベル側のカウントもY.sum(axis=0)で終了)
alphaについてはスムージング用のもの、指定しない限りはデフォルトの1となります。
_update_feature_log_prob()で最終的な$P(Y|X)$(クラスごとの特徴量の確率)を算出。

BernoulliNB._update_feature_log_prob()
def _update_feature_log_prob(self, alpha):
    """Apply smoothing to raw counts and recompute log probabilities"""
    smoothed_fc = self.feature_count_ + alpha
    smoothed_cc = self.class_count_ + alpha * 2

    self.feature_log_prob_ = np.log(smoothed_fc) - np.log(smoothed_cc.reshape(-1, 1))

_update_class_log_prior()は名前の通り、ラベル側の確率を計算する関数。
特に指定しない限りは、与えたデータを元に計算される模様。(elifの部分)

_BaseDiscreteNB._update_class_log_prior()
def _update_class_log_prior(self, class_prior=None):
    # 一部省略
    n_classes = len(self.classes_)
    if class_prior is not None:
        if len(class_prior) != n_classes:
            raise ValueError("Number of priors must match number of classes.")
        self.class_log_prior_ = np.log(class_prior)
    elif self.fit_prior:
        # 一部省略
            log_class_count = np.log(self.class_count_)
        # empirical prior, with sample_weight taken into account
        self.class_log_prior_ = log_class_count - np.log(self.class_count_.sum())
    else:
        self.class_log_prior_ = np.full(n_classes, -np.log(n_classes))

訓練部分を確認できたので、再び下記の周辺尤度の計算部分に戻ります。

_joint_log_likelihood()
def _joint_log_likelihood(self, X):
    # 一部省略
    neg_prob = np.log(1 - np.exp(self.feature_log_prob_))
    # Compute  neg_prob · (1 - X).T  as  ∑neg_prob - X · neg_prob
    jll = safe_sparse_dot(X, (self.feature_log_prob_ - neg_prob).T)
    jll += self.class_log_prior_ + neg_prob.sum(axis=1)
    return jll

途中のコメント部分に書いている通り、計算方法の工夫をしているのみ。
これでサンプルコードの書いている部分の中身の確認は完了です。
他のナイーブベイズについても、一部の関数が変わるだけで基本的な部分は同じ(はず)。

この内容を元に、自分の実装も修正出来そうです。

最後に

今回もライブラリの中身を確認する記事となりました。
文献を元に実装した際に、どうしても不安が残るので確認する面もありますが、
よく使われているライブラリの実装上の工夫も知れるので、たまにやっていくことになるのかと思います。

1
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
1