Markov Logic Network (2)

Continued from the previous article.

Weight Learning

Last time we looked at queries under the assumption that each formula $f_i$ comes with a given weight $w_i$; in practice, however, the $w_i$ themselves have to be learned.

Given a weight vector $\mathbf{w}=(w_1,\ldots,w_m)$, the probability distribution was

P(X=x|\mathbf{w})=\frac{1}{Z}\exp\left(\sum_iw_in_i(x)\right)\\
Z = \sum_x\exp\left(\sum_iw_in_i(x)\right)

so the log-likelihood is

\begin{align}
\log L(\mathbf{w}|X=x) &= \log P(X=x|\mathbf{w}) \\
                       &= \sum_iw_in_i(x) - \log Z
\end{align}

We want to maximize this, so we compute the gradient:

\begin{align}
\frac{\partial}{\partial w_i}\log L(\mathbf{w}|X=x) &= n_i(x) - \frac{\partial}{\partial w_i}\log Z \\
&= n_i(x) - \sum_{x'} n_i(x')\frac{1}{Z}\exp\left(\sum_j w_j n_j(x')\right) \\
&= n_i(x) - \sum_{x'} n_i(x') P(X=x'|\mathbf{w}) \\
&= n_i(x) - E_{x'}[n_i(x')]
\end{align}

In other words, we should move $w_i$ in the direction that makes the number of satisfied groundings of $f_i$ in the given world $x$ (that is, $n_i(x)$) equal to its expected count under the model ($E_{x'}[n_i(x')]$), i.e. the direction that drives the gradient to $0$.
The likelihood can then be maximized with this gradient using BFGS or the like. I also tried steepest descent, but its convergence was very slow, so use a method that exploits (approximate) second-order information, such as BFGS or L-BFGS.
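
Written as an update rule, the plain gradient ascent that proved too slow here is (with a hypothetical step size $\eta$, for illustration only)

w_i \leftarrow w_i + \eta\left(n_i(x) - E_{x'}[n_i(x')]\right)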

In practice, counting $n_i(x)$ is laborious, and $E_{x'}[n_i(x')]$ is even harder since it is a sum over all possible worlds, so Monte Carlo methods and the like become necessary. The conditional independencies expressed by the graph can also be exploited. That will be covered later.
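
For a flavor of the Monte Carlo approach, here is a minimal sketch (not part of the original computation) that estimates $E_{x'}[n_i(x')]$ by Gibbs sampling instead of enumerating every world. `count_fn`, which maps a 0/1 world vector to the count vector $(n_1(x),\ldots,n_m(x))$, and the sampler settings are hypothetical stand-ins:

import numpy as np

def estimate_expected_counts(count_fn, w, n_atoms, n_samples=5000, burn_in=500, seed=0):
    # Monte Carlo estimate of E[n_i(x)]: draw worlds by Gibbs sampling
    # and average the counts instead of summing over all 2^n_atoms worlds.
    rng = np.random.default_rng(seed)
    x = rng.integers(0, 2, size=n_atoms)   # random initial world
    total = np.zeros_like(w, dtype=float)
    kept = 0
    for t in range(burn_in + n_samples):
        for j in range(n_atoms):           # resample each ground atom in turn
            x[j] = 1
            s1 = w @ count_fn(x)           # unnormalized log-weight with the atom true
            x[j] = 0
            s0 = w @ count_fn(x)           # ...and with the atom false
            x[j] = int(rng.random() < 1.0 / (1.0 + np.exp(s0 - s1)))  # P(atom=1 | rest)
        if t >= burn_in:                   # discard burn-in, then accumulate counts
            total += count_fn(x)
            kept += 1
    return total / kept                    # approximates E[n_i(x)]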

Worked Example

Take the same model as last time,

$F_1: \neg\mathrm{Smokes}(x)\vee\mathrm{Cancer}(x): w_1$
$F_2: \neg\mathrm{Friends}(x,y)\vee \neg\mathrm{Smokes}(x)\vee\mathrm{Smokes}(y): w_2$
$F_3: \neg\mathrm{Friends}(x,y)\vee \mathrm{Smokes}(x)\vee\neg\mathrm{Smokes}(y): w_3$

with constants $C=\{A,B\}$, and suppose the observed world $x$ is

$\mathrm{Smokes}(A)=True, \mathrm{Smokes}(B)=True$
$\mathrm{Friends}(A,A)=\mathrm{Friends}(B,B)=False, \mathrm{Friends}(A,B)=\mathrm{Friends}(B,A)=True$
$\mathrm{Cancer}(A)=True, \mathrm{Cancer}(B)=False$

Let us compute $w_1,w_2,w_3$ for this case directly from the definitions.
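
As a cross-check on the code below: in this world $n_1(x)=1$, since the grounding for $B$ is violated ($\mathrm{Smokes}(B)$ holds but $\mathrm{Cancer}(B)$ does not), while $n_2(x)=n_3(x)=4$, since everyone smokes and hence every grounding of $F_2$ and $F_3$ is satisfied.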

import pandas as pd
import numpy as np
from scipy.special import logsumexp  # moved here from scipy.misc in newer SciPy
from scipy.optimize import fmin_bfgs
from itertools import product

const = ['A', 'B']
preds = [('Smokes', 1), ('Cancer', 1), ('Friends', 2)]  # Predicate and arity

ground_atoms = [
    (p, *args)
    for p, arity in preds
    for args in product(const, repeat=arity)
    ]

print('=== Ground Atoms ===')
print(ground_atoms)

formulas = [
    # (literals as (predicate, variable indices), negation flags, number of variables)
    ([('Smokes', (0,)), ('Cancer', (0,))], [1, 0], 1),                         # F_1
    ([('Friends', (0,1)), ('Smokes', (0,)), ('Smokes', (1,))], [1, 1, 0], 2),  # F_2
    ([('Friends', (0,1)), ('Smokes', (0,)), ('Smokes', (1,))], [1, 0, 1], 2),  # F_3
    ]

ground_formulas = []
for i, (clauses, neg, arity) in enumerate(formulas):
    for args in product(const, repeat=arity):
        ground_formula = [
            (p, *(args[j] for j in v))  # substitute constants for the variable indices
            for p, v in clauses
            ]
        ground_formulas.append((ground_formula, neg, i))

print('=== Ground Formulas ===')
print(ground_formulas)

# Generate all 2^8 = 256 possible worlds (one row per truth assignment of the ground atoms)
X = pd.DataFrame(columns=ground_atoms, data=list(product([1, 0], repeat=len(ground_atoms))))

def mloglike(w, world_idx):
    # Compute sum_i(w_i*n_i(x)) for every possible world
    S = np.zeros(len(X))
    for f, neg, i in ground_formulas:
        v = np.logical_xor(X[f], neg).any(axis=1)  # ground clause satisfied iff any literal is true
        S += w[i] * v

    # Negative log-likelihood of the observed world
    logP = S - logsumexp(S)
    return -logP[world_idx]

def grad_mloglike(w, world_idx):
    # Compute n_i(x) and sum_i(w_i*n_i(x)) for every possible world
    n = np.zeros((len(X), len(w)))
    S = np.zeros(len(X))
    for f, neg, i in ground_formulas:
        v = np.logical_xor(X[f], neg).any(axis=1)  # ground clause satisfied iff any literal is true
        S += w[i] * v
        n[:, i] += v

    # Log joint probability of every world
    logP = S - logsumexp(S)

    # Gradient: n_i(x) - E[n_i(x)], with the expectation computed in log space
    with np.errstate(divide='ignore'):  # log(0) = -inf is harmless inside logsumexp
        g = n[world_idx] - np.exp(logsumexp(np.log(n) + logP.reshape(-1, 1), axis=0))
    return -g

print('=== Training ===')
world = [(('Smokes', 'A'), True),
         (('Smokes', 'B'), True),
         (('Friends', 'A', 'A'), False),
         (('Friends', 'B', 'B'), False),
         (('Friends', 'A', 'B'), True),
         (('Friends', 'B', 'A'), True),
         (('Cancer', 'A'), True),
         (('Cancer', 'B'), False)
         ]
world = np.array(world)
world_idx = np.where((X[world[:,0]] == world[:,1]).all(axis=1))[0][0]

w0 = np.zeros(len(formulas))
w = fmin_bfgs(mloglike, w0, fprime=grad_mloglike, args=(world_idx,))
print(w)

The result is

w_1 = -0.8958807, w_2=w_3=11.96577936

Note that $w_2=w_3$, as expected from the symmetry between $F_2$ and $F_3$ in this world.
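
As a sanity check, the gradient at the learned weights should vanish, i.e. $n_i(x)=E_{x'}[n_i(x')]$; reusing the variables from the script above:

# The (negative) gradient at the optimum should be numerically zero,
# i.e. observed counts match expected counts.
print(grad_mloglike(w, world_idx))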
