199
205

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

本家Examplesで知る、新たなニューラルネットワーク『KAN』の威力

Last updated at Posted at 2024-05-08

文中の図は理がない限り、原論文あるいはドキュメントからの引用です。

KANってなに?

KAN(Kolmogorov-Arnold Networks) は2024年4月30日にプレプリント公開サイトarXivに投稿された論文

にて提案された従来のMLPとは異なる新たなニューラルネットワーク構造です。1

コルモゴロフ・アーノルド表現定理(Kolmogorov-Arnold representation theorem)に基づいて設計されており、非線形な活性化関数そのものを直接学習することから、パラメータ効率が良く、学習結果の解釈可能性が高いことが特徴です。

image.png
上の画像は$x, y$を入力として$\exp(\sin(\pi x)+y^2)$を正解として学習させるケースを示した図ですが、元の関数の関係性がそのまま活性化関数の形状として学習され現れていることが読み取れます。

コルモゴロフ・アーノルド表現定理

$$f(x) = f(x_1,...,x_n)=\sum_{q=1}^{2n+1}\Phi_q(\sum_{p=1}^n \phi_{q,p}(x_p))$$

数学的に厳密な解説は私にはできないので簡単に説明すると、任意の連続な多変数関数は有限個の連続な1変数関数の合成で表現できるという定理だそうです。

本家ドキュメント翻訳:
$f$は有界領域上の多変量連続関数である場合、単一変数の連続関数と加算の二項演算の有限合成として記述することができます。

難しく考えずに式を見ると、1変数関数$\phi_{q,p}(x_p)$の和の関数$\Phi_q(\cdot)$を更に複数足し合わせているだけの、それほど複雑ではない構造です。
この式から分かるようにコルモゴロフ・アーノルド表現は2層の入れ子構造になっています。

筆者はこの計算を次のような行列様の表現で示しています。

f(x)={\bf \Phi}_{\rm out}\circ{\bf \Phi}_{\rm in}\circ {\bf x}
\begin{split}{\bf \Phi}_{\rm in}= \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n}(\cdot) \\ \vdots & & \vdots \\ \phi_{2n+1,1}(\cdot) & \cdots & \phi_{2n+1,n}(\cdot) \end{pmatrix},\quad {\bf \Phi}_{\rm out}=\begin{pmatrix} \Phi_1(\cdot) & \cdots & \Phi_{2n+1}(\cdot)\end{pmatrix}\end{split}

式と見比べると、行列積の位置対応で右側の行列の要素が左側の行列の関数の入力になるという意味みたいですね。(私が無知なだけで一般的な表現?)

Kolmogorov-Arnold Network

前述の2層構造では表現力が限定される(特に実装の際に利用可能な関数では)ことから、2層以上の層を積み重ねて、より複雑なモデルを構築することを考えます。

その際先程の行列表現が役立ち、モデルを次のように表現することができます。

{\rm KAN}({\bf x})={\bf \Phi}_{L-1}\circ\cdots \circ{\bf \Phi}_1\circ{\bf \Phi}_0\circ {\bf x}

ここで一般化した層$\Phi$をKolmogorov-Arnoldレイヤーと呼びます。

\begin{split}{\bf \Phi}= \begin{pmatrix} \phi_{1,1}(\cdot) & \cdots & \phi_{1,n_{\rm in}}(\cdot) \\ \vdots & & \vdots \\ \phi_{n_{\rm out},1}(\cdot) & \cdots & \phi_{n_{\rm out},n_{\rm in}}(\cdot) \end{pmatrix}\end{split}

個々の関数$\phi_{i,j}$は後述の方法でモデル化するとして、ネットワーク全体の計算としては行列積の位置関係で前層の出力を関数に入力するだけの単純な仕組みで進めることができます。

個々の活性化関数のモデル化

筆者は個々の活性化関数$\phi_{i,j}$をB-スプライン曲線で表現することを提案しています。
image.png
B-スプライン曲線は制御点を中心とした釣鐘型の区分多項式の和で表現され、各係数が等しい場合区間内で定数となる特徴があります。

B-スプライン曲線は柔軟に様々な形状の曲線を表現することができますが、有限個の制御点では表現できる範囲が限られてしまう問題があります。
この問題に対して実装では、随時入力値の範囲に応じて制御点の位置を更新し、新たな制御点の係数を更新前の形状をなるべく再現できるように最小二乗法で決定するという手段が取られています。

擬似コード
x = 評価点
prev_spline_coeffs = spline_coeffs
prev_spline_bases = b_spline(x, grid)
prev_spline = sum(prev_spline_coeffs*prev_spline_bases)
grid = new_grid
new_spline_bases = b_spline(x, new_grid)
spline_coeffs = least_square_fit(new_spline_bases, prev_spline)

本家Examples

ここからはタイトルの通り本家ドキュメントのいくつかのExampleを通してKANの威力を実感していきます。

Example 3: Classification(分類)

よく見かける三日月が2つ並んだようなデータセットを用意します。

データセットの準備
from kan import KAN
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
import torch
import numpy as np

dataset = {}
train_input, train_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)
test_input, test_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)

dataset['train_input'] = torch.from_numpy(train_input)
dataset['test_input'] = torch.from_numpy(test_input)
dataset['train_label'] = torch.from_numpy(train_label[:,None])
dataset['test_label'] = torch.from_numpy(test_label[:,None])

X = dataset['train_input']
y = dataset['train_label']
plt.scatter(X[:,0], X[:,1], c=y[:,0])

まずは回帰問題(出力1次元, MSE Loss)として扱います。

学習
model = KAN(width=[2,1], grid=3, k=3)

def train_acc():
    return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).float())

def test_acc():
    return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())

model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc));
自動数式推定
lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','tan','abs']
model.auto_symbolic(lib=lib)
formula = model.symbolic_formula()[0][0]
formula
fixing (0,0,0) with sin, r2=0.967966050300312
fixing (0,1,0) with tan, r2=0.9801151730516574

$$\displaystyle 0.39 \sin{\left(3.08 x_{1} + 1.56 \right)} - 0.79 \tan{\left(0.94 x_{2} - 3.37 \right)} + 0.51$$

推定した数式の分類精度のテスト
# how accurate is this formula?
def acc(formula, X, y):
    batch = X.shape[0]
    correct = 0
    for i in range(batch):
        correct += np.round(np.array(formula.subs('x_1', X[i,0]).subs('x_2', X[i,1])).astype(np.float64)) == y[i,0]
    return correct/batch

print('train acc of the formula:', acc(formula, dataset['train_input'], dataset['train_label']))
print('test acc of the formula:', acc(formula, dataset['test_input'], dataset['test_label']))
train acc of the formula: tensor(1.)
test acc of the formula: tensor(1.)

駆け足でここまで書いてしまいましたが、軽く解説すると、1行目

model = KAN(width=[2,1], grid=3, k=3)

で、入力2・中間層なし・出力1のモデルを定義しています。
筆者はB-スプライン曲線の制御点の座標をグリッドと呼んでいますが、制御点(grid)は3点、kはB-スプライン曲線の次数で3です。
model.train()で学習を行ったのち、

lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','tan','abs']
model.auto_symbolic(lib=lib)

libで指定した関数の候補との適合度を計算し、式を推定しています。
活性化関数を直接学習しているからこそできる芸当です。

わずか2層、数十のパラメータのモデルでここまで実現可能というのは驚きです。
更にテストではさり気なく分類精度100%出てます。


続いて分類問題(出力2次元, CrossEntropy Loss)として扱います。

データセットの準備
from kan import KAN
import matplotlib.pyplot as plt
from sklearn.datasets import make_moons
import torch
import numpy as np

dataset = {}
train_input, train_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)
test_input, test_label = make_moons(n_samples=1000, shuffle=True, noise=0.1, random_state=None)

dataset['train_input'] = torch.from_numpy(train_input)
dataset['test_input'] = torch.from_numpy(test_input)
dataset['train_label'] = torch.from_numpy(train_label)
dataset['test_label'] = torch.from_numpy(test_label)

X = dataset['train_input']
y = dataset['train_label']
plt.scatter(X[:,0], X[:,1], c=y[:])
学習
model = KAN(width=[2,2], grid=3, k=3)

def train_acc():
    return torch.mean((torch.argmax(model(dataset['train_input']), dim=1) == dataset['train_label']).float())

def test_acc():
    return torch.mean((torch.argmax(model(dataset['test_input']), dim=1) == dataset['test_label']).float())

model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), loss_fn=torch.nn.CrossEntropyLoss());
自動数式推定
lib = ['x','x^2','x^3','x^4','exp','log','sqrt','tanh','sin','abs']
model.auto_symbolic(lib=lib)
formula1, formula2 = model.symbolic_formula()[0]
formula1
fixing (0,0,0) with sin, r2=0.8303828486153692
fixing (0,0,1) with sin, r2=0.7801497677237067
fixing (0,1,0) with x^3, r2=0.9535787267982471
fixing (0,1,1) with x^3, r2=0.9533594412300308

$$\displaystyle - 3113.07 \left(0.21 - x_{2}\right)^{3} - 807.36 \sin{\left(3.13 x_{1} + 1.42 \right)} - 120.29$$

formula2

$$\displaystyle - 3113.07 \left(0.21 - x_{2}\right)^{3} - 807.36 \sin{\left(3.13 x_{1} + 1.42 \right)} - 120.29$$

推定した数式の分類精度のテスト
# how accurate is this formula?
def acc(formula1, formula2, X, y):
    batch = X.shape[0]
    correct = 0
    for i in range(batch):
        logit1 = np.array(formula1.subs('x_1', X[i,0]).subs('x_2', X[i,1])).astype(np.float64)
        logit2 = np.array(formula2.subs('x_1', X[i,0]).subs('x_2', X[i,1])).astype(np.float64)
        correct += ((logit2 > logit1)*1.) == y[i]
    return correct/batch

print('train acc of the formula:', acc(formula1, formula2, dataset['train_input'], dataset['train_label']))
print('test acc of the formula:', acc(formula1, formula2, dataset['test_input'], dataset['test_label']))
train acc of the formula: tensor(0.9700)
test acc of the formula: tensor(0.9660)

先程との違いは

- model = KAN(width=[2,1], grid=3, k=3)
+ model = KAN(width=[2,2], grid=3, k=3)

で出力を2にしている点や、

- dataset['train_label'] = torch.from_numpy(train_label[:,None])
- dataset['test_label'] = torch.from_numpy(test_label[:,None])
+ dataset['train_label'] = torch.from_numpy(train_label)
+ dataset['test_label'] = torch.from_numpy(test_label)

正解ラベルが1階になっている点、

- model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc));
+ model.train(dataset, opt="LBFGS", steps=20, metrics=(train_acc, test_acc), loss_fn=torch.nn.CrossEntropyLoss());

損失関数にCrossEntropyLossを指定してる点くらいでしょうか。

Example 4: Symbolic Regression(シンボリック回帰)

シンボリック回帰は基本的には先程既に行った数式推定のことを指します(関数同定問題)。
この節はKANの紹介というよりは、pykanライブラリで提供されている機能の説明になってしまうので割愛します。

Example 6: Solving Partial Differential Equation (PDE)(偏微分方程式を解く)

2次元ポアソン方程式
$$\nabla^2 f(x,y) = -2\pi^2{\rm sin}(\pi x){\rm sin}(\pi y)$$
境界条件
$$f(-1,y)=f(1,y)=f(x,-1)=f(x,1)=0$$
を解きます。厳密解は
$$f(x,y)={\rm sin}(\pi x){\rm sin}(\pi y)$$
です。

長めのプログラム
from kan import KAN, LBFGS
import torch
import matplotlib.pyplot as plt
from torch import autograd
from tqdm import tqdm

dim = 2
np_i = 21 # number of interior points (along each dimension)
np_b = 21 # number of boundary points (along each dimension)
ranges = [-1, 1]

model = KAN(width=[2,2,1], grid=5, k=3, grid_eps=1.0, noise_scale_base=0.25)

def batch_jacobian(func, x, create_graph=False):
    # x in shape (Batch, Length)
    def _func_sum(x):
        return func(x).sum(dim=0)
    return autograd.functional.jacobian(_func_sum, x, create_graph=create_graph).permute(1,0,2)

# define solution
sol_fun = lambda x: torch.sin(torch.pi*x[:,[0]])*torch.sin(torch.pi*x[:,[1]])
source_fun = lambda x: -2*torch.pi**2 * torch.sin(torch.pi*x[:,[0]])*torch.sin(torch.pi*x[:,[1]])

# interior
sampling_mode = 'random' # 'radnom' or 'mesh'

x_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)
y_mesh = torch.linspace(ranges[0],ranges[1],steps=np_i)
X, Y = torch.meshgrid(x_mesh, y_mesh, indexing="ij")
if sampling_mode == 'mesh':
    #mesh
    x_i = torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)
else:
    #random
    x_i = torch.rand((np_i**2,2))*2-1

# boundary, 4 sides
helper = lambda X, Y: torch.stack([X.reshape(-1,), Y.reshape(-1,)]).permute(1,0)
xb1 = helper(X[0], Y[0])
xb2 = helper(X[-1], Y[0])
xb3 = helper(X[:,0], Y[:,0])
xb4 = helper(X[:,0], Y[:,-1])
x_b = torch.cat([xb1, xb2, xb3, xb4], dim=0)

steps = 20
alpha = 0.1
log = 1

def train():
    optimizer = LBFGS(model.parameters(), lr=1, history_size=10, line_search_fn="strong_wolfe", tolerance_grad=1e-32, tolerance_change=1e-32, tolerance_ys=1e-32)

    pbar = tqdm(range(steps), desc='description')

    for _ in pbar:
        def closure():
            global pde_loss, bc_loss
            optimizer.zero_grad()
            # interior loss
            sol = sol_fun(x_i)
            sol_D1_fun = lambda x: batch_jacobian(model, x, create_graph=True)[:,0,:]
            sol_D1 = sol_D1_fun(x_i)
            sol_D2 = batch_jacobian(sol_D1_fun, x_i, create_graph=True)[:,:,:]
            lap = torch.sum(torch.diagonal(sol_D2, dim1=1, dim2=2), dim=1, keepdim=True)
            source = source_fun(x_i)
            pde_loss = torch.mean((lap - source)**2)

            # boundary loss
            bc_true = sol_fun(x_b)
            bc_pred = model(x_b)
            bc_loss = torch.mean((bc_pred-bc_true)**2)

            loss = alpha * pde_loss + bc_loss
            loss.backward()
            return loss

        if _ % 5 == 0 and _ < 50:
            model.update_grid_from_samples(x_i)

        optimizer.step(closure)
        sol = sol_fun(x_i)
        loss = alpha * pde_loss + bc_loss
        l2 = torch.mean((model(x_i) - sol)**2)

        if _ % log == 0:
            pbar.set_description("pde loss: %.2e | bc loss: %.2e | l2: %.2e " % (pde_loss.cpu().detach().numpy(), bc_loss.cpu().detach().numpy(), l2.detach().numpy()))

train()

学習したKANをプロットしてみます。

model.plot(beta=10)

図から1層目は線形、2層目は正弦関数であるように思われるので、活性化関数をそれらに差し替えて追加学習します。

for i in range(2):
    for j in range(2):
        model.fix_symbolic(0,i,j,'x')

for i in range(2):
    model.fix_symbolic(1,i,0,'sin')

train()

推定された式を出力します。

formula, var = model.symbolic_formula(floating_digit=5)
formula[0]

$$\displaystyle 0.5 \sin{\left(3.14159 x_{1} - 3.14159 x_{2} + 7.85398 \right)} - 0.5 \sin{\left(3.14159 x_{1} + 3.14159 x_{2} + 1.5708 \right)}$$

少々分かりづらいですが、最初に示した厳密解と概ね一致する解が得られました。

Example 8: KANs’ Scaling Laws(KANのスケーリング則)

この節ではデータサイズとグリッドサイズを変化させながら学習後の損失の大きさのスケーリングを検証しています。
結果のみ紹介すると、
データサイズが十分であればグリッドサイズ<100の領域で$N^4$のスケーリング則に従って損失が減少するようです(著者の実装では)。

Example 9: Singularity(特異点を含む関数)

この節では
$$f(x,y)=sin(log(x)+log(y)) (x>0,y>0)$$

$$f(x,y)=\sqrt{x^2+y^2}$$
の特異点を持つ関数を学習するタスクについてテストを行っていますが、これらの関数では上手く学習を行えたそうです。
ただし、後者に関してはlossが小さくなるうちにある点でnanになってしまったとのことです。

Example 12: Unsupervised learning(教師なし学習)

$x_1, x_2, x_3, x_4, x_5, x_6$ の6つの変数があるとき、
$(x_1, x_2, x_3)$に$x_3={\rm exp}({\rm sin}(\pi x_1)+x_2^2)$ のような依存関係、
$(x_4,x_5)$に$x_5=x_4^3$ のような依存関係が存在し、$x_6$のみが独立であるとします。

この節では、これらの変数のみからこれら依存関係を検出することが可能かというタスクを検証しています。

これを行うアイデアとして、元のデータと別に変数間のインデックスをシャッフルしたデータを用意し、元のデータのラベルを1、シャッフルしたデータのラベルを0と設定することで分類問題に落とし込んでいます。
層の幅を6(入力)→1(選別用中間層)→1(出力)とし、中間層の活性化関数をガウス関数にすることで、シャッフルしていないデータかつ

\begin{align}
&{\rm exp}({\rm sin}(\pi x_1)+x_2^2)-x_3+(x_4=0)+(x_5=0)=0\\ 
\rightarrow&sin(\pi x_1)+x_2^2-\log(x_3)=0
\end{align}

のような出力でないと出力層が1にならないようにするという手段を取っています。

データセットの準備
from kan import KAN
import torch
import copy


seed = 1

# create dataset


def create_dataset(train_num=500, test_num=500):

    def generate_contrastive(x):
        # positive samples
        batch = x.shape[0]
        x[:,2] = torch.exp(torch.sin(torch.pi*x[:,0])+x[:,1]**2)
        x[:,3] = x[:,4]**3

        # negative samples
        def corrupt(tensor):
            y = copy.deepcopy(tensor)
            for i in range(y.shape[1]):
                y[:,i] = y[:,i][torch.randperm(y.shape[0])]
            return y

        x_cor = corrupt(x)
        x = torch.cat([x, x_cor], dim=0)
        y = torch.cat([torch.ones(batch,), torch.zeros(batch,)], dim=0)[:,None]
        return x, y

    x = torch.rand(train_num, 6) * 2 - 1
    x_train, y_train = generate_contrastive(x)

    x = torch.rand(test_num, 6) * 2 - 1
    x_test, y_test = generate_contrastive(x)

    dataset = {}
    dataset['train_input'] = x_train
    dataset['test_input'] = x_test
    dataset['train_label'] = y_train
    dataset['test_label'] = y_test
    return dataset

dataset = create_dataset()
モデル定義と中間層の活性化関数をガウス関数に固定
model = KAN(width=[6,1,1], grid=3, k=3, seed=seed)
model.fix_symbolic(1,0,0,'gaussian',fit_params_bool=False)
学習・プロット
model.train(dataset, opt="LBFGS", steps=50, lamb=0.002, lamb_entropy=10.0);
model.plot(in_vars=[r'$x_{}$'.format(i) for i in range(1,7)])

別の乱数シードで実行すると↓

一度の実行で1つの依存関係しか検出できない弱点はありますが、それぞれの変数の関数形状まで見出すことができています。

おわりに

ここで登場したKANはせいぜい3層程度のものでしたが、それでもこれだけの柔軟性と表現力を備えていることに遥かな可能性を感じたのは自分だけでしょうか。
現状の実装では速度に課題があるようですが、B-スプラインにこだわらず並列処理などの高速化に対応したアーキテクチャが発明されていけばNNの主流がMLPからKANに置き換わる未来もありえなくないと思います。

一人でも多くの人がこの可能性に関心を持って、開発がより一層加速することを期待しています。

  1. B-スプライン曲線で関数$\phi$を表現しているところを、それ自体MLPで表現してしまえば従来のNNと等価だという意見もあります。

199
205
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
199
205

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?