20
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

画像データに対するGraph Neural Network(GNN)入門 ② 〜実践編〜

Posted at

はじめに

前回の記事では、グラフニューラルネットワーク(Graph Neural Network; GNN)が気象予測の領域でブレークしたことに触発され、グラフ理論の基礎からGNN概要や画像データに対してGNNを適用した研究紹介までの入門記事を書きました。

本記事では、より理解を深めるために、Pythonによる画像データに対するGNNの適用方法を紹介します。また、実務や研究で利用する際は、手元の画像データをグラフに変換し、GNNで訓練・評価するまでの手続きが必要なので、これらを一連の流れを説明しようと思います。実装は、Pytorchをベースにグラフデータ用に開発されているPytorch geometricを利用します。

アウトライン

  1. Pytorch geometricの使い方
  2. 画像データのグラフへの変換
    1. ノードの選択
    2. エッジの接続
    3. ノードの特徴量ベクトル
    4. 画像のグラフ変換デモ
  3. GNNモデル設計
    1. グラフ畳み込み層(GCN)
  4. モデル学習・評価
  5. Appendix. CNNとの比較(実装コード)

1. Pytorch geometricの使い方

  • Pytorch geometricでは、ノードやエッジ、特徴量をpytorchのtensor型で保持する
  • ノードとエッジはedge_index(edge_from, edge_toの2次元配列)で表現され、各ノードは特徴量X、ラベルyを持つことができる
  • 下図は、あるグラフがノードを4つ(0 ~ 3)、各ノードは特徴量x_i (3次元)とラベルy_i (1次元)を持っていたときの例を示している

image.png

  • これをtensorで表現するには、edge_fromとedge_toの配列を用意し、それぞれのノードの接続状況を行列で表現する
  • またノードごとの特徴量とラベルも同様にtensorで表現する
  • これをコードで実装すると以下の通りとなる
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
import networkx as nx

# エッジ情報(ノード間の接続を表現)
edge_from = [0, 0, 1, 1, 1, 2, 2, 3]
edge_to = [1, 2, 0, 2, 3, 0, 1, 1]
edge_index = torch.tensor([edge_from, edge_to])

# 特徴量Xを設定(ノードごと)
x_0 = [0, 1, 2]
x_1 = [1, 2, 3]
x_2 = [2, 3, 4]
x_3 = [3, 4, 5]
x = torch.tensor([x_0, x_1, x_2, x_3])

# ラベルyを設定(ノードごと)
y_0 = [0]
y_1 = [0]
y_2 = [1]
y_3 = [0]
y = torch.tensor([y_0, y_1, y_2, y_3])

# グラフオブジェクトへの変換
data = Data(x=x, y=y, edge_index=edge_index)

# グラフの可視化
nxg = to_networkx(data)
nx.draw(nxg,
        with_labels = True,
        node_color = y,
        alpha=0.5)

image.png

グラフノードは場所の概念を持たないので、ノードの位置は前述の図と異なるが、エッジの接続状況やラベルの色分けは指定した通りとなっている

各要素へのアクセスは以下のメソッドを用いて確認することもできる

# 各要素へのアクセス
print("特徴量X: ")
print(data.x)

print("ラベルy: ")
print(data.y)

print("エッジ情報: ")
print(data.edge_index)

2. 画像データのグラフへの変換

GNNを計算するにあたって、まずは画像をグラフ構造(ノード、エッジ、特徴量、ラベル)に変換する必要がある

論文等では、グラフ・エンコーディング(Graph encoding)等の用語でまとめられていることが多いが、基本的には以下の手順で実行される

2-1. ノードの選択

  • ピクセルベース: 画像の各ピクセルをノードとみなす処理。画像データの情報欠損はないが、グラフサイズが大きくなり、計算リソースが膨大になる可能性がある。
  • 領域ベース(スーパーピクセル): 類似した近傍ピクセルでクラスタを形成し、クラスタを一つのノードとみなす。これにより、グラフのサイズを小さくし、計算効率を向上させることができる。スーパーピクセルの定義方法は様々な手法が提案されている。

2-2. エッジの接続

  • 空間的隣接性による接続: 空間的に近接するピクセルやスーパーピクセルの情報をもとにエッジを張る方法(例. 上下左右のピクセルにエッジを張る)。
  • 特徴量類似度に基づく接続: 画像の特徴量(色、テクスチャなど)の類似度を計算してエッジを張る方法(例. ユーグリッド距離のk-means)。

2-3. ノードの特徴量ベクトル

  • 画像特徴量の利用: RGB、グレースケールやピクセル位置(xy)を特徴量として直接的に利用する。
  • 周辺テクスチャ情報の利用: エッジ検出や局所的なテクスチャ情報をノード特徴を利用する(例. パッチごとにエンコーディングしたベクトル)。

image.png

2-4. 画像のグラフ変換デモ

ここでは、MINIST(手書きの数字画像データ)をグラフデータに変換するケースを示す。

MNISTは画像サイズがそこまで大きくなく(28x28ピクセル)、次元数ものグレースケール(1次元)なので、シンプルにピクセル=ノードとして、エッジ接続はユーグリッド距離の近傍(k=5)と定義、特徴量はグレースケールと元画像のxy座標を含む3次元ベクトルで表現してみる。

image.png

オリジナル画像(2次元配列のグレースケール画像)がグラフに変換されている様子がみてとれる。
しかし、MNISTのオリジナル画像を見ると殆どが黒いピクセル(全体の約80%)であるため、ピクセルベースでのグラフ変換しても、非常に情報量的に疎(スパース)なグラフが生成されている。
このままでは計算効率が悪いので、完全に黒のピクセルはノードとせずにグラフ変換を行うものとする。

ここまでの要件をPythonのコードに落とし込むと以下の様になる。
ここでは処理の分かりやすさの為にユーグリッド距離の計算やk近傍の選定もスクラッチで書き下しているが、ニューラルネットワークの学習を行う際はこれだと処理効率が悪いので、より高速計算に最適化されたライブラリを積極的に利用する(後述)。

# MNISTをグラフに変換
def convert_img2graph(img, top_k = 5):
    # マスク準備
    mask = img != -1
    # 輝度情報の特徴量化
    x = img[mask].reshape(-1, 1).astype(float)
    # xy座標情報の特徴量
    coords = np.stack(np.meshgrid(np.arange(img.shape[1]), np.arange(img.shape[2])), axis=0)
    xcords = coords[0].reshape(1, 28, 28)[mask].reshape(-1, 1)
    ycords = coords[1].reshape(1, 28, 28)[mask].reshape(-1, 1)
    # 特徴量の作成
    x = np.concatenate([x, xcords.astype(float), ycords.astype(float)], axis=0)

    # 要素ごとに類似度(ユーグリッド距離)を計算
    edge_from = []
    edge_to = []
    # ノード(ピクセル)毎に類似度計算(エッジ情報の作成)
    for t in range(len(x)):
        similarity = []
        for u in range(len(x)):
            _ = np.sqrt(np.sum((x[t] - x[u])**2))
            similarity.append(_)
        # 類似度がTOP5のindexを取得
        topk_index = np.argsort(similarity)[::-1][:top_k]
        # グラフのエッジ情報を作成
        edge_from.append(np.repeat(t, top_k).tolist())
        edge_to.append(topk_index.tolist())

    # flatten (1次元配列に変換)
    edge_from = np.array(edge_from).flatten()
    edge_to = np.array(edge_to).flatten()
    # 双方向のエッジを作成(無向グラフの作成)
    edge_index = np.array([edge_from, edge_to])
    edge_index = np.concatenate([edge_index, edge_index[::-1, :]], axis=1)

    # グラフオブジェクトへの変換
    edge_index = torch.tensor(edge_index)
    x = torch.tensor(x)
    data = Data(x=x, y=None, edge_index=edge_index)
    return data

上記のコードを用いてグラフ変換したMNISTをいくつか可視化してみる。全ピクセルををノートに変換するよりもグラフ表現がかなり効率化されている。

image.png

3. GNNモデル設計

画像データのグラフ変換ができたので、次にGNNのモデル設計を行う。

今回は入力データに対してGCN (Graph Convolutional Neural network)を用いてグラフの特徴量抽出を行い、最終的に全結合ネットワークで10クラス分類を行うモデルを設計する。

image.png

グラフ畳み込み層(GCN)

GCNの理論的な説明は前回の記事を参照とする

Pytorch geometricがGCNをサポートしているので実装は割と簡単に行うことができる。グラフ畳み込み処理を行った後に活性化関数(ReLU)に通すことで、非線形な表現と計算効率の向上を実現する(通常の畳み込みニューラルネットワークと同様)。

また、今回は畳み込み処理を複数回繰り返していくが、その度にエッジを貼り直す処理を導入した。これにより高次元空間でグラフノードの関係性を捉えることができると考えた為である。処理内容は上述のconvert_img2graphの関数をPytorch向けに書き直しているが、同様の処理である。

最後の全結合層に入力する前に、グラフデータを特徴量行列に変換するため、ノードの次元を平均値で周辺化し、次元削減を行なっている。ここでは平均値で集約を行なっているが、最大値や最頻値等の様々な圧縮方法が存在する。また、全結合層ではドロップアウト層を組み込み過学習の抑制を行なっている。

# モデル設計(一部)
class GCN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = GCNConv(3, 16) # 入力と出力お特徴量の次元数を定義してインスタンス作成
				self.fc = nn.Linear(120, 10) # 全結合層
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

		def edge_reconnect(self, x, k):
          distance_matrix = torch.cdist(x, x, p=2)
          _, indices = torch.topk(distance_matrix, k)
          edge_from = torch.arange(x.shape[0]).view(-1, 1).repeat(1, k).to(device)
          edge_to = indices
          edge_index = torch.stack([edge_from.view(-1), edge_to.view(-1)], dim=0)
          return edge_index

		def forward(self, data):
        x = self.conv(x, edge_index) # 順伝播の計算では、入力xとノードとエッジに出力層(edge_index)を渡す
				x = self.relu(x)
				edge_index = self.edge_reconnect(x, k = 5) # エッジの貼り直し処理

				・・・

				x = global_mean_pool(x, batch) # ノード次元を周辺化
				x = self.dropout(x)
        x = self.fc(x)
        return x

4. モデルの学習・評価

ここまで解説した内容を一連のコードに落とし込む。
Pytorchの一連のコードの書き方は過去記事 Pytorch Template 個人的ベストプラクティス(解説付き)を参照。

import random
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Subset
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

# リソースの選択(CPU/GPU)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# データ準備
PATH = "./mnist"
transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(),torchvision.transforms.Normalize((0.5,), (0.5,))])
train_set = torchvision.datasets.MNIST(root = PATH, train = True, download = True, transform = transform)
test_set = torchvision.datasets.MNIST(root = PATH, train = False, download = True, transform = transform)

# MNISTの画像をランダムサンプリング(計算コスト削減用)
def sample_mnist_dataset(dataset, sampling_percentage):
    # 指定された割合でサンプルの数を計算
    num_samples = int(len(dataset) * sampling_percentage)
    # ランダムにインデックスを選択
    indices = random.sample(range(len(dataset)), num_samples)
    # 選択されたインデックスでサブセットを作成
    subset = Subset(dataset, indices)
    print(f"サンプリングした{('トレーニング' if dataset else 'テスト')}セットのデータ数: {len(subset)}")
    return subset

# 任意(環境により学習データが大き過ぎる場合はランダムサンプリング)
# train_set = sample_mnist_dataset(train_set, 1)
# test_set = sample_mnist_dataset(test_set, 1)


# 画像毎にグラフをまとめてオブジェクト化するクラス
class Mydataset(torch.utils.data.Dataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __len__(self):
        return len(self.dataset)

    def convert_img2graph(self, img, top_k=5):
        # マスク準備
        mask = img != -1
        # 輝度情報の特徴量化
        x = img.view(-1, 1).float()
        x = x[mask.view(-1)]
        # xy座標情報の特徴量
        coords = torch.stack(torch.meshgrid(torch.arange(img.size(1)), torch.arange(img.size(2))), dim=-1)
        coords = coords.view(-1, 2)
        coords = coords[mask.view(-1)]
        # 特徴量の作成
        x = torch.cat([x, coords.float()], dim=1)

        # ユークリッド距離の計算
        distance_matrix = torch.cdist(x, x, p=2)
        # 類似度がTOP5のノードを取得
        _, indices = torch.topk(distance_matrix, k=top_k, largest=False)
        # エッジインデックスの生成
        edge_from = torch.arange(x.size(0)).view(-1, 1).repeat(1, top_k)
        edge_to = indices
        edge_index = torch.stack([edge_from.view(-1), edge_to.view(-1)], dim=0)
        edge_index = torch.cat([edge_index, edge_index.flip(0)], dim=1)  # 双方向エッジ
        return Data(x=x, y=None, edge_index=edge_index)

    def __getitem__(self, index):
        img, label = self.dataset[index]
        graph_data = self.convert_img2graph(img, top_k = 5)
        return graph_data, label

# モデル定義
class GCN(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(3, 16)
        self.conv2 = GCNConv(16, 64)
        self.conv3 = GCNConv(64, 128)
        self.conv4 = GCNConv(128, 256)
        self.conv5 = GCNConv(256, 512)
        self.conv6 = GCNConv(512, 800)

        self.fc1 = nn.Linear(800, 512)
        self.fc2 = nn.Linear(512, 120)
        self.fc3 = nn.Linear(120, 10)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.5)

    def edge_reconnect(self, x, k):
          distance_matrix = torch.cdist(x, x, p=2)
          _, indices = torch.topk(distance_matrix, k)
          edge_from = torch.arange(x.shape[0]).view(-1, 1).repeat(1, k).to(device)
          edge_to = indices
          edge_index = torch.stack([edge_from.view(-1), edge_to.view(-1)], dim=0)
          return edge_index

    def forward(self, data):
        x = data.x
        edge_index = data.edge_index
        batch = data.batch

        x = self.conv1(x, edge_index)
        x = self.relu(x)
        # edge_index = self.edge_reconnect(x, k = 5)

        x = self.conv2(x, edge_index)
        x = self.relu(x)
        # edge_index = self.edge_reconnect(x, k = 5)

        x = self.conv3(x, edge_index)
        x = self.relu(x)
        # edge_index = self.edge_reconnect(x, k = 5)

        x = self.conv4(x, edge_index)
        x = self.relu(x)
        # edge_index = self.edge_reconnect(x, k = 5)

        x = self.conv5(x, edge_index)
        x = self.relu(x)
        # edge_index = self.edge_reconnect(x, k = 5)

        x = self.conv6(x, edge_index)
        x = self.relu(x)
        # edge_index = self.edge_reconnect(x, k = 5)

        x = global_mean_pool(x, batch)

        x = self.dropout(x)
        x = self.fc1(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.fc3(x)

        return x


# データセットの作成
train_graphset = Mydataset(train_set)
test_graphset = Mydataset(test_set)

BATCH_SIZE = 1024
LEARNING_RATE = 1e-2
EPOCH = 150

# データローダーの準備
train_loader = DataLoader(train_graphset, batch_size = BATCH_SIZE, shuffle = True)
test_loader = DataLoader(test_graphset, batch_size = BATCH_SIZE, shuffle = False)

# モデルの定義・損失関数・最適化手法の設定
model = GCN()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

log_train_loss, log_test_loss = [], []

# 学習ループ
for epoch in tqdm(range(EPOCH)):
    # train loop
    model.train()
    train_batch_loss = []
    for (inputs, labels) in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_batch_loss.append(loss.item())
    # validation loop
    model.eval()
    test_batch_loss = []
    with torch.no_grad():
        for input, label in test_loader:
            input, label = input.to(device), label.to(device)
            output = model(input)
            loss = criterion(output, label)
            test_batch_loss.append(loss.item())
    log_train_loss.append(np.mean(train_batch_loss))
    log_test_loss.append(np.mean(test_batch_loss))

# 学習結果の可視化
plt.figure(figsize=(6,6))
plt.plot(log_train_loss, label="train_loss")
plt.plot(log_test_loss, label="test_loss")
plt.xlabel('EPOCH')
plt.ylabel('LOSS')
plt.title('loss')
plt.legend()
plt.grid()

# 正解率の計算
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for input, label in test_loader:
        input, label = input.to(device), label.to(device)
        outputs = model(input)
        predicted = torch.argmax(outputs, dim=1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

# Calculate accuracy
accuracy = 100 * correct / total
print(f'Accuracy on the test images: {accuracy:.2f}%')

上記のコードでモデルを学習した結果、分類精度が 95.41 % を達成することができた。
グラフ変換の前処理やハイパーパラメータのチューニングで更なる精度の底上げは期待できる。
ちなみに、従来のCNN(Convolutional Neural network)をほぼ同条件で学習・評価した結果 97.6% だったので、概ね同様のパフォーマンスが実現できていることが分かる(CNNによる実装はAppendix参照)。

しかし、モデルの学習速度は従来のCNNの方が圧倒的に早かった。ここから、MNISTの様に比較的、画像内の情報構造が単純であり、分類というシンプルなタスクに対して、わざわざGNNを適用するメリットはなく、(あくまで個人的な推論だが)GNNが向いているのは、

  1. 入力画像内に様々な要素が存在し、その関係性を考慮すると情報濃度が向上する性質を持つ画像(気象画像やレイアウト図等)に対して、
  2. セマンティックセグメンテーションや物体検出といった画像内のオブジェクト間の関係性を考慮する必要がある複雑なタスク、
    が適していると考えられる。

さいごに

グラフニューラルネットワークが気象予測の論文が話題を呼んでいたのをきっかけに、画像データに対するグラフニューラルネットについての基礎を2回にわたって記事にしてみました。グラフ構造は、テーブルデータ、画像、テキストデータのより一般的な表現であり、上手く取り扱えば従来アプローチでは難しかった技術的な課題解決にもつながる可能性を感じることができました。一方で、一般表現であるが故に前処理やモデルの設計の柔軟性が高く、上手く取り扱わないとかえってパフォーマンスの悪化を招く難しさも感じることができました。

自然言語処理を筆頭に画像領域も基盤モデル的なアプローチがトレンドのなか、グラフ領域はまだ基盤モデル的な研究アプローチは見かけません。引き続きモデルデザイン中心の研究トレンドなのか、どこかでのタイミングで大規模データドリブンのアプローチに移行してくのか、今後もトレンドを追っていきたいと思います。

Appendix. CNNによる学習・評価

# モデル定義
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, stride=2)

        self.conv1 = nn.Conv2d(1, 16, 3)
        self.conv2 = nn.Conv2d(16, 32, 3)

        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.pool(x)
        x = self.conv2(x)
        x = self.relu(x)
        x = self.pool(x)
        x = x.view(x.size()[0], -1)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

# パラメータ設定
BATCH_SIZE = 1024
LEARNING_RATE = 1e-3
EPOCH = 100

# データローダーの準備
train_loader = torch.utils.data.DataLoader(train_set, batch_size = BATCH_SIZE, shuffle = True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size = BATCH_SIZE, shuffle = False)

# モデルの定義・損失関数・最適化手法の設定
model = CNN()
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

log_train_loss, log_test_loss = [], []

# 学習ループ
for epoch in tqdm(range(EPOCH)):
    # train loop
    train_batch_loss = []
    for (inputs, labels) in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_batch_loss.append(loss.item())
    # validation loop
    model.eval()
    test_batch_loss = []
    with torch.no_grad():
        for input, label in test_loader:
            input, label = input.to(device), label.to(device)
            output = model(input)
            loss = criterion(output, label)
            test_batch_loss.append(loss.item())
    log_train_loss.append(np.mean(train_batch_loss))
    log_test_loss.append(np.mean(test_batch_loss))

# 学習結果の可視化
plt.figure(figsize=(6,6))
plt.plot(log_train_loss, label="train_loss")
plt.plot(log_test_loss, label="test_loss")
plt.xlabel('EPOCH')
plt.ylabel('LOSS')
plt.title('loss')
plt.legend()
plt.grid()

# 正解率の計算
model.eval()
correct, total = 0, 0
with torch.no_grad():
    for input, label in test_loader:
        input, label = input.to(device), label.to(device)
        outputs = model(input)
        predicted = torch.argmax(outputs, dim=1)
        total += label.size(0)
        correct += (predicted == label).sum().item()

# Calculate accuracy (should be 97%)
accuracy = 100 * correct / total
print(f'Accuracy on the test images: {accuracy:.2f}%')
20
10
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
20
10

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?