3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

ニューラルネットワークの可視化ライブラリを作ってみた【PyPIリリース】

Last updated at Posted at 2025-02-12

はじめに

 PyTorchで記述されたニューラルネットワークを以下のように可視化するライブラリtorchLinearVisをPyPIにリリースさせて頂きました。例えば

demo.py
import torch.nn as nn

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.model = nn.Sequential(
            ......
            nn.Linear(5, 5),
            ......
        )
        
    def forward(self, x):
        return self.model(x)        

で記述されるニューラルネットワーク中の全結合層部分(5ユニット→5ユニット)の重みが学習過程でどのように変化しているかを以下のような.htmlファイルを生成してブラウザ上からアニメーションを確認できたりするものになります。特に別個で環境を作らずとも生成された1つのhtmlファイルだけがあれば結果を確認できるところが特徴です。

top2_re.gif

 PyPIリリースに伴う手続きのハードルは思っていたより低かったため、公開したライブラリの概要とPyPIリリースまでの一連の流れを紹介させて頂きます。PyPIへのリンクは https://pypi.org/project/torchLinearViz/ に、ソースコードは https://github.com/guard-mann/torchLinearViz.git になります。

動機

 ユニットレベルでニューラルネットワークを可視化するPyPIライブラリが欲しかった。ユニットレベル(=グラフが膨大)だからこそブラウザでグラフを確認して拡大したり見づらいノードを動かしたりと自由度の高い操作を可能にするようなライブラリが作りたかった。

概要

 今回リリースしたライブラリtorchLinearVisは、ニューラルネットワークの全結合(Linear)層をユニット単位のネットワークとして描画する。他の種類の層はそのままレイヤー単位で描画する。全結合層のネットワークにおける(ノード)は全結合層のユニット、ネットワークの(エッジ)は全結合重みに対応しており、線の太さが重みの絶対値と対応している。

overview.png

何ができるのか?

 最初に以下コマンドでインストールを行う。(https://pypi.org/project/torchLinearViz/ より)

$ pip install torchLinearViz

 torchLinearVizはPyTorchを使った訓練コードの前後に以下のような簡単な構文を差し込むことで、学習終了後に1つの.htmlファイルを生成する。デモのコードを以下に添付しておく。(torchLinearVizの差し込み箇所は❗️のコメントで表示させている)

demoMNIST.py
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
# ❗️ 所望のライブラリ呼出ができることを確認
from torchLinearViz import TorchLinearViz

# 🔹 1. データの前処理
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])

# 🔹 2. データセットの読み込み
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

# 🔹 3. MLP(全結合ニューラルネットワーク)モデルの定義
class MLP(nn.Module):
   def __init__(self):
       super(MLP, self).__init__()
       self.model = nn.Sequential(
           nn.Flatten(),  # 画像(28x28) → 1次元 (784)
           nn.Linear(28*28, 5),  # 入力 784 → 隠れ層 5 
           nn.Linear(5, 5),  # 隠れ層 5 → 5
           nn.Linear(5, 10)  # 出力 10クラス
       )

   def forward(self, x):
       return self.model(x)

# 🔹 4. モデルの作成
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MLP().to(device)

# 🔹 5. 損失関数と最適化手法
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# ❗️ torchLinearViz初期化
torchlinearviz = TorchLinearViz(model)

# 🔹 6. 学習
epochs = 100
for epoch in range(epochs):
   model.train()
   total_loss = 0

   for images, labels in train_loader:
       images, labels = images.to(device), labels.to(device)

       optimizer.zero_grad()
       outputs = model(images)
       loss = criterion(outputs, labels)
       loss.backward()
       optimizer.step()
       total_loss += loss.item()
   
   # ❗️ 重み更新を反映
   torchlinearviz.update(model, images)
   
   print(f"Epoch [{epoch+1}/{epochs}], Loss:{total_loss/len(train_loader):.4f}")

# ❗️ htmlファイル出力
torchlinearviz.end()

生成されたhtmlファイルはブラウザで確認できる。デモのhtmlファイルはこちらに添付してある。描画可能なモードとして、単純に重みの絶対値の大きさをそのままエッジの太さと対応付けて表示させるモード(Value)と、1つ前のエポックとの差分の大きさをエッジの太さと対応付けて表示させるモード(Diff (abs))の2つをSwitch Dataボタンから選択可能である。他にも再生速度を変えられたり重みの絶対値のスケールを調整したり好きな位置、拡大度合いを調節することもできる。

top7.gif

好きな位置・拡大度合いを調節してアニメーションを再生する他にも、重なって見づらいノードを好きに動かしてアニメーションを再生することもできる。これはノード数が大きいグラフの画面が潰れて見づらい点を解消するための機能である。

top8.gif

さらにスタイリッシュなUIを好む方向けにカラースキームとしてLight(デフォルト)とDarkを用意してある。

スクリーンショット 2025-02-11 14.39.57.png

これらの添付gifはqiitaに上げる過程でデータ量の問題で遅く粗く見えているが実際はもう少しマシである(こちらのREADMEにも同じようなgifを添付してある)。

PyPIリリースまでの手順

作業は以下の様に工程1 ~ 工程5までの5段階で順番に進めた。

工程1. PyTorchのモデルからグラフ情報を抽出・保存するpythonコードを作成

 今回はpythonでPyTorchのhookを用いてレイヤーの接続関係を取り出した。

analyse_graph.py
for name, layer in model.named_modules():
   layer_names[layer] = name  # レイヤー名を保存
   layer.register_forward_hook(hook_fn)  # フックを登録

 特にユニット数が大きい場合 数千 x 数千 といった大規模のネットワークとなり得るため、1つの全結合層を表現可能な最大ノード数MAXNODEを決め、それより大きい場合はユニット数がMAXNODEとなるように等間隔でサンプリングする仕様にした。例えばMAXNODEを25として、ある全結合層への入力ユニット数が100だった場合、UNIT_4, UNIT_8, ... , UNIT_100のようにユニットを間引いて取ってくる仕様にしている。生成されたJsonファイルはこんな感じのフォーマットになる。

graphInfo.json
{
    "nodes":[
        {
            "data": {
                "id": "model.0",
                "type": "Flatten"
            }
        },
        {
            "data": {
                "id": "model.1",
                "type": "Linear"
            }
        },
        {
            "data": {
                "id": "model.1_out",
                "type": "Linear"
            }
        },
        {
            "data": {
                "id": "UNIT_model.1_in_0",
                "type": "UNIT"
            }
        },
        ...

    ],
    "edges":[
                {
            "data": {
                "id": 1,
                "source": "model.1",
                "target": "UNIT_model.1_in_0",
                "width": 1
            }
        },
        {
            "data": {
                "id": 32,
                "source": "model.1",
                "target": "UNIT_model.1_in_31",
                "width": 1
            }
        },
        {
            "data": {
                "id": 63,
                "source": "model.1",
                "target": "UNIT_model.1_in_62",
                "width": 1
            }
        },
        ...
    ]
}

工程2. グラフ描画(フロントエンド)コード作成

 JavaScriptでCytoscape1を用いた。Cytoscapeにはいくつかの組み込みのレイアウトが用意されている2(cose, grid, preset)が、ニューラルネットワークというシーケンシャルに処理を行うグラフを描画するにあたって、いずれも見やすさの観点で採用には至らなかった。また組み込みでないレイアウト(cola, klay, dagre)についても比較を行い、Dagreというレイアウト3を採用した。このライブラリはDAG(有向非巡回グラフ)を見やすく配置してくれ、シーケンシャルなデータをツリー上に見やすく配備してくれる。Jsonファイルを読み込んでHTMLを生成するtoHTML.pyコードを作成した。

スクリーンショット 2025-02-11 1.50.20.png

toHTML.py
html_template = f"""<!DOCTYPE html>
<html lang="ja">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Summary | TorchLinearViz</title>

    /* Cytoscapeをインポート */
    <script src="https://cdnjs.cloudflare.com/ajax/libs/cytoscape/3.23.0/cytoscape.min.js"></script>
    
    /* Dagreをインポート */
    <script src="https://cdn.rawgit.com/cpettitt/dagre/v0.7.4/dist/dagre.min.js"></script>
    <script src="https://cdn.rawgit.com/cytoscape/cytoscape.js-dagre/1.5.0/cytoscape-dagre.js"></script>
    
    ...略...
"""

# HTML ファイルとして保存
output_html = "epoch_visualizer.html"
with open(output_html, "w", encoding="utf-8") as f:
    f.write(html_template)

なお、環境としてはCDN(ContentDelivery Network)からCytoscape.jsを直接読み込んで使っている。これにより、外部環境を特にセットアップせずに簡単な描画が可能となる。

工程3. フロントとバックエンドを統合

 Jsonでグラフ情報を吐き出すコードを作成し、その後それを読み込んで描画するコードも作成した。あとはこのJson経由の形式を廃止してメンバ変数上でこの管理を行う仕様に変え、pythonコード上で、torchLinearViz.end()が呼び出されたらHTMLファイルを生成する仕組みにするため、toHTML.pyでテンプレートを作っていた内容をそのままこのendメソッドに組み込んだ。全体としてクラスに必要な要素は、初期化のための__init__メソッドと、Epochごとに呼び出されて更新された重みの情報を含んだネットワークの解析結果をメンバ変数self.json_data_listへと追加、保存していくUpdateメソッド、それを組み込んでHTMLファイルに書き込むendメソッドとなった。

torchLinearViz.py
class torchLinearViz:
    def __init__(self, model):
        self.model = model
        self.json_data_list = []
        ...
        
    def extract_and_save_graph():
        ...
        """バックエンドコードを呼び出す"""
        result = analyse_graph()
        self.json_data_list.append(result)

    def update():
        extract_and_save_graph()
        
    def end(self):
        graphDataJson = json.dumps(self.json_data_list)
        html_template = f"""<!DOCTYPE html>
        ...
        <script>
          let graphData = {graphDataJson} // テンプレートにクラス内変数を組み込む
          ...
        </script>
        ...
        """
        with open(output_html, "w", encoding="utf-8") as f:
            f.write(html_template)
    

この時点でリポジトリ構成は以下のように極めてシンプルなものとなっている。

$ tree .
.
└── torchLinearViz
    ├── __init__.py
    ├── analyse_graph.py
    └── torchLinearViz.py

工程4. PyPIにコミットできるように必要なファイルを追加

4-1. pyproject.tomlを作成

 torchはどうせ元々入っている人が使うだろうとも思ったが、自動で入れてくれる方が気が利いていると感じたため、dependencies = ["torch"]を記載してある。あとは各種事務的な情報を記入する。例えばAuthorやLicense、バージョンやソースコードのリンクなどなど。ソースコードのリンクを貼っておくことで、PyPIのページからリポジトリへと飛んでもらえる。AuthorsはPyPIアカウントと紐づいていなければいけないのかと思っていたが、これはどうやらなんでもいいらしい。

4-2. MANIFEST.inを作成

今回は以下のように非Pythonファイルを記載。

MANIFEST.in
include README.md LICENSE

4-3. LICENSEを作成

MITライセンスを添付した。OSIが定義してくれているもの4を参照した。


 以上をもって、あとはビルドするだけである。現時点でのリポジトリ構成は以下のようになっている。
$ tree .
.
├── LICENSE
├── MANIFEST.in
├── README.md
├── pyproject.toml
└── torchLinearViz
    ├── __init__.py
    ├── analyse_graph.py
    └── torchLinearViz.py

工程5. ビルドとテスト

5-1. ビルド

プロジェクトのルートディレクトリで、以下のコマンドを使用。

$ pip install build
$ python -m build

これを実行すると、dist/hoge-X.X.X-py3-none-any.whldits/hoge-X.X.X.tar.gzがルートディレクトリに生成される。

5-2. テスト1 (ローカル)

 最初は4-4.にてビルド済みの.tar.gzファイルをローカルでテストする。ローカルで適当なvenv環境を作って、そこで、

$ pip install dist/hoge-X.X.X-py3-none-any.whl

または

$ pip install dist/oge-X.X.X.tar.gz

を実行。テストコードを任意のディレクトリに作成し、当該ライブラリを呼び出せるか確認。所望のhtmlファイルがカレントディレクトリに生成されていることを確認した。

5-3. PyPIアカウントとtest-PyPIアカウントを作成

 https://pypi.org/account/register/ よりPyPIアカウントを作成。メールアドレス、ユーザー名、パスワードが必須で求められる。また、これらの記入を終えたら2要素認証を追加する。この後の手順で使うtest-pypi5についても同様の手順でhttps://test.pypi.org/account/register/ からアカウントを作成した。手順は同じだが、pypiアカウントとtest-pypiアカウントは別個のサービスであり、片方を作ればもう片方でもログインできるようなシステムではない。アカウントをそれぞれで登録する必要がある。

5-4. テスト2(test-PyPI)

 いきなりアップロードするのは怖いので、最初はtest-pypi5でアップロードした。PyPIに安全にパッケージをアップロードするための公式ツール6であるtwineを用いてアップロードを行う。これはtest-pypiだけでなく、pypiでも同じである。アップロードに際してAPIトークンが必要となる。これを取得して、新規作成した~/.pypircファイルに記載する。

~/.pypirc
[distutils]
index-servers =
   pypi
   testpypi

[pypi]
repository = https://upload.pypi.org/legacy/
username = __token__
password = pypi-XXXX # <- APIトークン

[testpypi]
repository = https://test.pypi.org/legacy/
username = __token__
password = pypi-XXXXX # <- APIトークン

これでアップロードが可能になる。アップロードは以下のコマンドで行う。

$ pip install twine
$ twine upload --repository testpypi dist/*

test-PyPIの良いところは、公開したプロジェクトを簡単に消せるというところである。何度でもやり直しが効くため、本番前の練習として良い環境である。

5-5. PyPIアップロード

$ pip install twine
$ twine upload --repository pypi dist/*

同じようにやる。これは一度公開したら基本的には削除はできず、アップデートを入れる修正が主となるため、慎重に行う。

スクリーンショット 2025-02-11 12.55.59.png

↑アップロードが完了し、見慣れた画面からリリースしたライブラリを確認できるようになる。

所感

 PyPIへのコミットが思っていたより簡単で開かれたコミュニティであることに驚いた。(審査があるのかと思っていた)

  1. URL : https://js.cytoscape.org/

  2. URL : https://blog.js.cytoscape.org/2020/05/11/layouts/

  3. URL : https://github.com/dagrejs/dagre

  4. URL : https://opensource.org/license/mit

  5. URL : https://test.pypi.org/ 2

  6. URL : https://pypi.org/project/twine/

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?