LoginSignup
1
0

More than 1 year has passed since last update.

Integrated Gradientsでグラフニューラルネットワークを可視化する

Posted at

Integrated Gradientsでグラフニューラルネットワークを可視化する

TL;DR

Integrated Gradients(統合勾配)を利用して、グラフニューラルネットワークによる推論モデルの可視化を行ってみました。よくあるデータセットだと効果が直感的にわかり難いと感じたため、日本語ベンチマーク用データセットであるJGLUEMARK-jaを利用し、日本語の文を係り受け解析した上でグラフ構造に変換して入力し、Itegrated Gradientsによりどの語句として表現されたノードの反応を可視化してみました。

環境

  • dgl==0.9.1
  • ginza==5.1.2
  • networkx==2.8.7
  • capsum==0.5.0

データセット

ヤフー株式会社が提供するGLUEの日本語版です。いくつかのタスク用のデータセットが提供されていますが、今回は Multilingual Amazon Reviews Corpus (MARC)から日本語のレビューを抜粋し、星評価による判断の他にクラウドソーシングによる評価を交えてpositivenegativeの2クラスにラベルされたMARC-jaを使用します。

モジュールのインストール

係り受け解析と特徴量抽出のためにginzaを利用します。グラフニューラルネットワーク用フレームワークはdgl、Itegrated Gradientsの計算にcaptumを使います。

!pip install -U ginza
!pip install dgl
!pip install captum
!pip install torch
!pip install networkx

データの準備

MARC-jaのデータセットがファイルとして用意してある前提です。詳細は以下を参照してください。

全てのデータを利用すると処理が重くなるため、トレーニング用として1000、評価用として200サンプル抜き出します。

import pandas as pd

train_df = pd.read_json(
    'data/marc_ja-v1.0/train-v1.0.json', lines=True)
train_df = pd.concat([train_df[train_df['label'] == 'positive'].sample(
    500), train_df[train_df['label'] == 'negative'].sample(500)]).sample(frac=1)
test_df = pd.read_json(
    'data/marc_ja-v1.0/valid-v1.0.json', lines=True)
test_df = pd.concat([test_df[test_df['label'] == 'positive'].sample(
    100), test_df[test_df['label'] == 'negative'].sample(100)]).sample(frac=1)

display(train_df)
display(test_df)

文をグラフ構造に変換する

ginza(SpaCy)を利用して文を係り受け解析し、NetworkXのグラフ構造を経てdgl用のグラフに変換します。変換部分のコードは以下を参考にしています。

import spacy

nlp = spacy.load('ja_ginza')
import dgl
import networkx as nx
import matplotlib.pyplot as plt
import japanize_matplotlib
import numpy as np


def create_graph(doc: spacy.tokens.doc.Doc):
    edges = []
    vector_dict = {}
    node_id2names = {}
    for token in doc:
        vector_dict[token.text] = token.vector
        for child in token.children:
            edges.append((token.text, child.text))
            vector_dict[child.text] = child.vector
    graph = nx.DiGraph(edges)
    for i, (n, attr) in enumerate(graph.nodes.items()):
        graph.nodes[n]['feat'] = vector_dict[n]
        graph.nodes[n]['name_index'] = i
        node_id2names[i] = n
    g = dgl.from_networkx(graph, ['name_index', 'feat'])

    return g, node_id2names


def plot_tree(g, node_id2names, label, with_weight=False):
    relabels = {}
    base_size = 80
    base_color = [.4, .5, .9, .65]
    node_sizes = []
    for n, attr in g.nodes.items():
        relabels[n] = node_id2names[attr['name_index'].item()]
        if with_weight:
            scale = 1. + attr['weight'].detach().numpy()
            node_sizes.append(np.exp(scale**2) * base_size)
    g2 = nx.relabel_nodes(g, relabels)

    pos = nx.nx_agraph.graphviz_layout(g2, prog='dot')
    fig = plt.figure(figsize=(15, 10))
    ax = fig.add_subplot(1, 1, 1)
    if with_weight:
        nx.draw(g2, pos, with_labels=True, node_size=node_sizes,
                node_color=base_color, arrowsize=15, font_family='IPAexGothic', ax=ax)
    else:
        nx.draw(g2, pos, with_labels=True, node_size=300,
                node_color=base_color, arrowsize=15, font_family='IPAexGothic', ax=ax)
    ax.set_title(
        f'Label: {"positive" if label == 1 else "negative"}')
    plt.show()


sentence = train_df.iloc[0]['sentence']
print(f'Sentence: {sentence}')

g, node_id2names = create_graph(
    nlp(sentence))
plot_tree(g.to_networkx(
    node_attrs=['name_index', 'feat']), node_id2names, label=1)

以下のように文章がグラフ構造に変換されます。

Sentence: このゲーム、すぐに倉庫が満杯になりイライラしてくる事が多くなってきます。それ理由は倉庫枠が50あったとすると普通は50種類収納できると考えられると思いますが、違います。総数が50個までしか収納できないのです。分かりやすく例えると、ニンジン6個と小麦9個を倉庫に入れると枠を15使ってしまうのです。最初のうちは気にならないのですが、プレイし続けると倉庫のやりくりに頭を抱える事が多くなってやる気が削がれます。

以下でトレーニング、評価用のデータセットを全てグラフ構造に変換します。

def create_features_and_labels(df):
    features = []
    labels = []
    node_id2names = []

    for index, row in df.iterrows():
        sentence = row['sentence']
        g, names = create_graph(nlp(sentence))

        features.append(g)
        label = 1 if row['label'] == 'positive' else 0
        label = np.array(label)
        labels.append(label)
        node_id2names.append(names)

    return features, labels, node_id2names


train_features, train_labels, train_node_id2names = create_features_and_labels(
    train_df)
test_features, test_labels, test_node_id2names = create_features_and_labels(
    test_df)

モデルの準備

推論モデルを準備します。GraphConvを利用した比較的単純なモデルです。ラベルは1:positive0:negativeに変換しているため、出力はsigmoidで0-1の範囲に変換しています(後述の損失関数ではbinary_cross_entropyを使用しています)。

import torch.nn as nn
from dgl.nn import GraphConv, MaxPooling, AvgPooling


class GCN(nn.Module):
    def __init__(self,
                 in_feats,
                 num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, 32, allow_zero_in_degree=True)
        self.conv2 = GraphConv(32, 16, allow_zero_in_degree=True)
        self.pooling = AvgPooling()
        self.linear1 = nn.Linear(16, 256)
        self.linear2 = nn.Linear(256, num_classes)

    def forward(self, g, h):
        h = self.conv1(g, h)
        h = self.conv2(g, h)
        h = self.pooling(g, h)
        h = self.linear1(h)
        h = F.relu(h)
        h = self.linear2(h)
        out = F.sigmoid(h)

        return out


model = GCN(train_features[0].ndata['feat'].shape[1], 1)
print(model)
GCN(
  (conv1): GraphConv(in=300, out=32, normalization=both, activation=None)
  (conv2): GraphConv(in=32, out=16, normalization=both, activation=None)
  (pooling): AvgPooling()
  (linear1): Linear(in_features=16, out_features=256, bias=True)
  (linear2): Linear(in_features=256, out_features=1, bias=True)
)

トレーニング

import torch
import torch.nn.functional as F
import numpy as np
from statistics import mean


def train(model, train_features, train_labels, report_period=500):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    losss = []

    for step, (feature, label) in enumerate(zip(train_features, train_labels)):
        g = feature.add_self_loop()
        logits = model(g, feature.ndata['feat'])
        _label = torch.tensor(label).reshape(-1, 1).float()

        loss = F.binary_cross_entropy(logits, _label)
        losss.append(loss.clone().detach().numpy())
        # print(logits, _label, loss)

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step > 0 and step % report_period == 0 or step == (len(train_features) - 1):
            print(f'step: {step + 1}, loss: {np.array(losss).mean()}')
            losss = []


for i in range(5):
    train(model, train_features, train_labels)

評価

from sklearn.metrics import classification_report

preds = []
labels = []

for feature, label in zip(test_features, test_labels):
    g = feature.add_self_loop()
    pred = model(g, feature.ndata['feat'])

    preds.append(1 if pred > 0.5 else 0)
    labels.append(label)

confusion_matrix = np.zeros((2, 2))
for p, l in zip(preds, labels):
    confusion_matrix[p][l] += 1

cr = classification_report(labels, preds, output_dict=True)
pd.DataFrame(cr).T

評価用データセットの評価結果は以下のとおりです。概ね0.78〜0.80程度のAccになります。

可視化

captumIntegratedGradientsでIntegrated Gradientsを計算します。計算した統合勾配はノード分出力されますが、わかりやすくするために標準化した上で、ノードの属性であるweightに割り当てています。このweightを利用して、networkXによる可視化時にノードの大きさを調整しています。

from captum.attr import IntegratedGradients
from functools import partial

# test_target_graph_index = 64 # positive sample
# test_target_graph_index = 77 # negative sample
test_target_graph_index = 77
g = test_features[test_target_graph_index].add_self_loop()
label = test_labels[test_target_graph_index]
sentence = test_df.iloc[test_target_graph_index]["sentence"]

h = g.ndata['feat'].clone().requires_grad_(True)
ig = IntegratedGradients(partial(model, g))

feat_attr = ig.attribute(
    h, target=0, internal_batch_size=g.num_nodes(), n_steps=10)

node_weights = feat_attr.abs().sum(dim=1)
node_weights = (node_weights - node_weights.min()) / node_weights.max()
g.ndata['weight'] = node_weights

print(f'Sentence: {sentence}')
plot_tree(g.remove_self_loop().to_networkx(
    node_attrs=['name_index', 'feat', 'weight']), test_node_id2names[test_target_graph_index], label=label, with_weight=True)

positive example

negative example

参考文献

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0