3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Kolmogorov-Arnold Networks (KAN)をGANで使ってみようという試み

Last updated at Posted at 2024-07-29

はじめに

Kolmogorov-Arnold Network(KAN)は、MITから発表されてからここ数ヶ月の間、懐疑的、または肯定的な目を向けられている。今まで私たちが使用してきたMulti-layer perceptron(MLP)と類似する点は見られるが異なる形をしており、論文内では自然科学の再発見に使える可能性が示唆されていた。すでにいくつか、深層生成モデルに組み込もうという流れからkan-gptなどが開発されているのだが、今回は基礎的なモデルで検証を行おうということから、そのKANをGANで実装してみた場合どうなるのか検証した。ここでは素朴なモデルを検証している。

目次

1.Kolmogorov-Arnold Networks(KAN)とは?
2.敵対的生成ネットワーク(GAN)とは?
3.今回の目的
4.動作環境、条件など
5.pytorchとpykanでの調査
6.efficient-kan
7.まとめと展望

※ミスがあればご指摘願います。

Kolmogorov-Arnold Networks(KAN)

kolmogorv-Arnold Networksというのは、2024年4月30日にプレプリント公開サイトarXivに投稿された論文[KAN:Kolmogorv-Arnold Networks]にて提案された従来のMLPとは異なる新たなニューラルネットワーク構造である。コルモゴロフ・アーノルド表現定理(Kolmogorov-Arnold representation theorem)に基づいて設計されていて、従来のMLPのような学習を行わない。
KANにはいくつかの利点があり、大きく分けて二つ、解釈性の高さとパラメータを抑えられるということだ。KANは、非線形変換のみ行い、それがエッジ上で行われるという構造をしている。
pykanライブラリが提供されており、現状多くの方向で試すことが可能だ。
筆者なりの解釈だが、学習するときの線形変換と非線形変換のポイントが大きく異なることで、利点が生まれているというものだと認識している。

ここでは、深く説明するつもりがないため、次のサイトを参照していただきたい。筆者は次のサイトから上記の言葉を引用しているので、筆者が何を言いたかったのかがわかると思う。

参照、引用元
論文 https://arxiv.org/abs/2404.19756

KAN (Kolmogorov-Arnold Networks) の利点を単純化して理解する:https://dalab.jp/archives/journal/advantages-of-kolmogorov-arnold-networks/

敵対的生成ネットワーク(GAN)

Generator(生成器)とDiscriminator(判別器)の二つのモデルを作成し、それを競わせることによって学習をする特殊なものだ。Generatorは本物に近い画像を生成し、Discrininatorをだますことを目的にしていて、Discrininatorは本物と偽物を区別し、Generatorに騙されないように学習するというもの。互いに競争関係にあることから「敵対的」という表現をされている。
最近は、transformerができたことや、GANに欠点が多すぎることからスポットライトが当てられることが減っているイメージが筆者にはある。しかし、今まで様々なことに対して応用をされてきたことから、筆者としては期待を持っている。
詳細な情報については、wikiなどの方が詳しく乗っていると思うのでそちらを読むことを推奨する。
wiki:

敵対的生成ネットワーク (GAN) - MATLAB & Simulink:

今回の目的

今回の目的は、KANがGANに適するものなのかを調べることが一番の目的だ。他に挙げるとするならば、KANは解釈性に優れていて科学の再発見に使われることを目的としていたことから、画像や音楽を別の観点から評価することに使えるのではないか?という期待があることだろう。

動作環境、条件など

動作環境

非常に重たい動作になった時のためにgoogle colabを使用することにした。いろいろと導入できるライブラリは限られているが、誰もが手軽に試せる環境で再現性が高いことからも最適だと言えるだろう。
※断じて筆者のpcのスペックが低いからではない。

pytorchとpykanでの調査:
pykanは多くの発展がされているが、発展したものを使用せずに基礎に対して、忠実にいくために論文やgithubのリポジトリにも存在しているpykanとpytorchというオーソドックスな形でプログラムを組んでいく。
pykanは非常に計算がおおいのでシンプルなモデルにする

pykanにはLBFGSというoptimizerが用意されているがこれを使わない。理由としては、使用するメモリが多すぎてうまく動かないからだ。これは、今現在は見られなくなってしまっているがtutorialという文書に書いてあった。

pytorchとpykanでの調査

ライブラリの環境導入

ライブラリ
pip install pykan
pip install pytorch
pip install torchvision

ライブラリのインポート

import
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
import torchvision.transforms as tfs
import matplotlib.pyplot as plt
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

データのロード

data_loard
def preprocess_img(x):
    x = transforms.ToTensor()(x)
    return (x - 0.5) / 0.5
def deprocess_img(x):
    return (x + 1.0) / 2.0

class ChunkSampler(sampler.Sampler):
    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

train_set = MNIST('./data', train=True, download=True, transform=transform)
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
val_set = MNIST('./data', train=True, download=True, transform=transform)
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))

imgs = deprocess_img(next(iter(train_data))[0].view(batch_size, 784)).numpy().squeeze()

ただ単に学習を行うのでは難しい場合がある(収束失敗を今回は想定)故に次の二つのサイトより参考にして上記のようにsamplingなどを行った。

KAN_GAN(ここではgenerator:MLP、discrininator:KANの畳み込みが使われていた)

How to Train a GAN? Tips and tricks to make GANs work

Deep Learning for Computer Visionより引用

生成器と判別器

生成器のプログラム

Generator
class Generator(nn.Module):
  def __init__(self, latent_dim):
    super(Generator, self).__init__()
    self.kan = KAN([100,256,512,1024,784],seed=42).speed()
    self.tanh = nn.Tanh()

  def forward(self, x):
    x = self.kan(x)
    x = self.tanh(x)
    return x.view(x.size(0), 1, 28, 28)

判別器のプログラム

Discriminator
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    self.kan = KAN([784,1024,512,1],seed=42).speed()

  def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.kan(x)

モデルのインスタンス化

instans
generator = Generator().to(device)
discriminator = Discriminator().to(device)

オプティマイザの定義

optimizer
optimizer_G = optim.Adam(generator.parameters(), lr=3e-4, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=3e-4, betas=(0.5, 0.999))
bce_loss = nn.BCEWithLogitsLoss()

トレーニングループ

判別器のトレーニング

discriminator_training
z = (torch.rand(batch_size, NOISE_DIM) - 0.5) / 0.5
z = z.to(device)
fake_img = generator(z)
logits_real = discriminator(real)
logits_fake = discriminator(fake)

size = logits_real.shape[0]
true_labels = torch.ones(size, 1).float().to(device)
fake_labels = torch.zeros(size, 1).float().to(device)
d_loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, fake_labels)

optimizer_D.zero_grad()
d_loss.backward()
optimizer_D.step()

生成器のトレーニング

generator_training
z = (torch.rand(batch_size, NOISE_DIM) - 0.5) / 0.5
z = z.to(device)
fake_img = generator(z)
logits_fake = discriminator(fake)

size = logits_fake.shape[0]
true_labels = torch.ones(size, 1).float().to(device)
g_loss = bce_loss(logits_fake, true_labels)

optimizer_G.zero_grad()
g_loss.backward()
optimizer_G.step()

全体的なプログラム

all_training
for epoch in range(epochs):
    for i, (imgs, _) in enumerate(train_data):
        batch_size = imgs.shape[0]
        real = imgs.to(device)

        # Discriminator_training
        z = (torch.rand(batch_size, NOISE_DIM) - 0.5) / 0.5
        z = z.to(device)
        fake_img = generator(z)
        logits_real = discriminator(real)
        logits_fake = discriminator(fake)
        
        size = logits_real.shape[0]
        true_labels = torch.ones(size, 1).float().to(device)
        fake_labels = torch.zeros(size, 1).float().to(device)
        d_loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, fake_labels)
        
        optimizer_D.zero_grad()
        d_loss.backward()
        optimizer_D.step()

        # Generator_training
        z = (torch.rand(batch_size, NOISE_DIM) - 0.5) / 0.5
        z = z.to(device)
        fake_img = generator(z)
        logits_fake = discriminator(fake)

        size = logits_fake.shape[0]
        true_labels = torch.ones(size, 1).float().to(device)
        g_loss = bce_loss(logits_fake, true_labels)

        optimizer_G.zero_grad()
        g_loss.backward()
        optimizer_G.step()

ここは、先ほどのKAN_GANからラベルなどのポイントを引用しているのでほぼ同じプログラムになっている。申しわけない。
画像出力

plot
with torch.no_grad():
    sample_z = torch.randn(16, NOISE_DIM).to(device)  # 4つの画像を生成
    generated_imgs = generator(sample_z).cpu()
    grid = torchvision.utils.make_grid(generated_imgs, nrow=4, normalize=True)  # 2行で表示
    plt.figure(figsize=(4, 4))
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()

といった感じに行ってみた。一応これは、KANとKANを使用している。
結論から言うと、これを使おうというのは、実用性に欠けている。"時間がかかりすぎる"のだ。
Q:1時間かけてどれだけ進んだか?
A:"A100:0epoch"
どうすればいいのかわからない……。という感じだ。
次のことを試してみることにした。
7月13日にpykanの内部構造が変更されてMultKANに変更された影響で作られた"スピードモデル"だ。
使い方は次の通り。

pykanのtutorialにあるspeedのexampleを引用
モデルに対して、.speed()と付け加えるだけである。ただし、ここでモデルの表示はできなくなる。速さの代わりに解釈性が失われるのだ。
ここで筆者は、もとのプログラムでどのようなモデルの形状をしているのかを確認することにした。この時点で解釈性が悪いのであれば、ここでスピードモードにしても問題はないだろう。元のモデルを表示する時間が1時間を超え、そして、51GBもあったランタイムをクラッシュさせたので次のようにプログラムを組んだ。

これがその画像なのだが……。
KAN-GAN.png
「解釈性が上がった」とは?これならば、MLPで作られたモデルの方が解釈性が高いように思える。
本末転倒な気がするのだが、今回はKANがGANで使える存在なのかを調査することが目的なのでそのまま進める。
次のようにプログラムを変更した。

generator
class Generator(nn.Module):
  def __init__(self, latent_dim):
    super(Generator, self).__init__()
    self.kan = KAN([100,256,512,1024,784],seed=42).speed()
    self.tanh = nn.Tanh()

  def forward(self, x):
    x = self.kan(x)
    x = self.tanh(x)
    return x.view(x.size(0), 1, 28, 28)
discriminator
class Discriminator(nn.Module):
  def __init__(self):
    super(Discriminator, self).__init__()
    self.kan = KAN([784,1024,512,1],seed=42).speed()

  def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.kan(x)

ここで大きな問題が発生する。pykanはMultiKANに変更されてからまだ時間がそこまで経っていないため問題が多くあるのだ。speedを用いるとcudaつまり、gpuが使えなくなってしまう。それでも早くはなったのでKANとKANの組み合わせを使ってみた。ハイパラメータはある程度調整してみた。結果としては、先ほどのものが筆者の実験ではすべてのランダムな入力に対して同等に画像の出力を行うことができた。
次のものが結果の画像だ。
image.png
これは、約10時間、10epochで生成された画像だ。数字とは、お世辞にも言い難い。これ以上epoch数を伸ばせばランタイムの接続を切られてしまうことがあるので、speedを用いたKANでもさすがに限界を感じた。(筆者のものは収束するポイントまで行きついていない可能性もあるので、時間がある人は学習率を変更して行ってほしい。)

efficient-kan

この題名でわかる通り、KANをefficient-kanに変更したのだ。先ほどまでの実験の話で分かってもらえたと思うが、KANは時間がかかりすぎるモデルだ。また、新たなモデルであるということから最適化がされきっていないのだ。故に最適化に向かうため様々なものが開発されている。例を挙げるとするならば、今回使用するefficient-kanやfast-kan、ChebyKANなどだ。どれも元のKANの一番のボトルネックとなっているBスプライン構造を別のものに置き換えて速度を上げようというものである。だが、全てのモデルが利点であった解釈性を大きく下げてしまうという問題を抱えている。問題はいずれも同じで、他のものに置き換わっているだけで、先ほどのspeedモードと変わらないというわけだ。

efficient-kanって何?

筆者のバイアスや感覚が入ってしまう可能性があるのでリポジトリをgoogle翻訳で直訳しただけのものを貼っておく。
元の実装のパフォーマンスの問題は、主に、異​​なる活性化関数を実行するためにすべての中間変数を展開する必要があることに起因します。in_features入力とout_features出力を持つレイヤーの場合、元の実装では、活性化関数を実行するために、入力を形状を持つテンソルに展開する必要があります(batch_size, out_features, in_features)。ただし、すべての活性化関数は、B スプラインである固定された基底関数のセットの線形結合です。そのため、計算を、入力を異なる基底関数で活性化してから線形に結合するように再定式化できます。この再定式化により、メモリ コストが大幅に削減され、計算が簡単な行列乗算になり、順方向パスと逆方向パスの両方で自然に機能します。

問題は、KAN の解釈可能性にとって重要であると主張されているスパース化(batch_size, out_features, in_features)にあります。著者らは、入力サンプルに定義された L1 正則化を提案しましたが、これはテンソルに対する非線形演算を必要とするため、再定式化と互換性がありません。私は代わりに、L1 正則化を重みに対する L1 正則化に置き換えました。これはニューラル ネットワークでより一般的であり、再定式化と互換性があります。著者の実装には、論文で説明されている正則化に加えて、実際にこの種の正則化も含まれているため、役立つと思います。これを検証するには、さらに実験が必要ですが、少なくとも、効率が求められる場合、元のアプローチは実行不可能です。

もう 1 つの違いは、学習可能な活性化関数 (B スプライン) の他に、元の実装には各活性化関数の学習可能なスケールも含まれていることです。この機能を含めるようenable_standalone_scale_splineにデフォルトで設定されているオプションを提供しましたTrue。無効にするとモデルの効率は上がりますが、結果に悪影響を与える可能性があります。さらに実験が必要です。

次より、直訳して引用

efficient-kanでの実験

ライブラリのインポート

import torch
import torch.nn as nn
import torch.optim as optim
from efficientkan import KAN
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
import torchvision.transforms as tfs
import matplotlib.pyplot as plt
import torchvision
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

この先のプログラムは、モデルの設計以外は変わっていないため変更したモデルのプログラムのみ書いておく

generator:KANとdiscrinimator:KANのセット

生成器のプログラム

generator_kan
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.kan = KAN([100,256,512,1024,784])
        self.Tanh = nn.Tanh()

    def forward(self, x):
        x = self.kan(x)
        return x.view(x.size(0), 1, 28, 28)

判別器のプログラム

discriminator_kan
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.kan = KAN([784,1024,512,1])

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.kan(x)

生成された画像

学習率の変更により精度の高い画像が生成される可能性が見られたため

再調査中

皆様に誤った情報を発信してしまったことを深くお詫び申し上げます。
下のものも同様です。少々お待ちください。

generator:MLPとdiscrinimator:KANのセット

生成器のプログラム

generator_mlp
class Generator(nn.Module):
    def __init__(self, noise_dim=NOISE_DIM):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(noise_dim, 512),
            nn.ReLU(True),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.ReLU(True),
            nn.BatchNorm1d(1024),
            nn.Linear(1024, 784),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.model(x)
        return x.view(x.size(0), 1, 28, 28)

一応バッチ最適化をしている。これは収束失敗を避けるためだ。
判別器のプログラム

discriminator_kan
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.kan = KAN([784, 1024, 1])

    def forward(self, x):
        x = x.view(x.size(0), -1)
        return self.kan(x)

生成された画像
image.png
image.png
かなり早めにこのような画像になって収束した。数字らしいと感じられるのが生成されたのは今回設定したセットだけだった。モード崩壊(1のみが生成)がこの後に起こるが、これはかなり強いスプライン力を持つKANの特性からくるものだと考えられる。
総合的に見て、GANにKANを用いるのは困難なようだ。GANでの生成にあまり適していない様に見える。

まとめと考察、今後の展望

まとめ

KANはMLPの良き制御者(discriminator)としてGANに使えるかもしれないが、生成側としては良きmodelとは言えないと考えられるだろう。

考察

KANとKANの組み合わせの良くないところは、互いに互いを学習しすぎる過学習がどれだけ小さなものでも発生してしまうということでKANとKANの組み合わせはGeneratorが弱すぎる、というかdiscriminatorが強すぎたのでdiscriminatorの学習率を10分の1まで落とすことにより収束することが確認された。
また、KANとMLPの組み合わせのみがうまく行ったのは、過学習と言っていいほど学習するKANにたいして、MLPがKANほどの強い学習を行わなかったからなのではないかと推測される。ここからは、筆者が参考にしていたgithubに存在するより美しい画像を生成できる畳み込み層を保持したものについて、畳み込み層が用意されることによってKANの学習が安定したからだと考察している。畳み込み層は、より高度でフィットしやすかったのではないかと考察している。また、筆者が独自でepoch数やハイパラを変更していたときに起こった話なのだが、筆者が生成した画像でも少し感じられたと思う。そう、1のモード崩壊だ。これが幾度か散見された。よって筆者としての総合的な意見としては、KANはGANに向かないモデルである。

展望

今回の結果からdisciriminatorとして、制御を行おうという点でKANは使えるだろう。DCGANやMuseGANに流用することもよいかもしれない。
また、KANがとても探索に優れたモデルであることは、disciriminatorの能力を抑えなければうまく動かないこともあるほどであったことからよくわかった。故に、すでに開発されているkanformer(TransfromerをKANにしたもの)などは強い有用性を持つのではないか?ということである。

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?