Help us understand the problem. What is going on with this article?

【 self attention 】簡単に予測理由を可視化できる文書分類モデルを実装する

More than 1 year has passed since last update.

はじめに

Deep Learning モデルの予測理由を可視化する手法がたくさん研究されています。
今回はその中でも最もシンプルな(しかし何故かあまり知られていない)self attentionを用いた文書分類モデルを実装したので実験結果を紹介します。
この手法では、RNNモデルが文書中のどの単語に注目して分類を行ったか可視化することが可能になります。

2019/04追記

本記事で紹介したモデルをAllenNLPを使用して書き直した記事を公開しました。

attentionの復習

attentionとは(正確な定義ではないですが)予測モデルに入力データのどの部分に注目するか知らせる機構のことです。
attention技術は機械翻訳への応用が特に有名です。
例えば、日英翻訳モデルを考えます。翻訳モデルは”これはペンです”という文字列を入力として"This is a pen"という英文を出力しますが、「pen」という文字を出力する際、モデルは入力文の「ペン」という文字に注目するはずです。このように入力データのある部分に「注目する=attention」という機構を予測モデルに組み込むことで、種々のタスクにおいいて精度が向上することが報告されてきました。
また、このattentionを可視化することで「入力データのどの部分に注目して予測を行ったか」という形で予測理由の提示を行うことができます。
attentionについての説明と実装は

がとても参考になります。

self attention を利用した分類

今回は、attentionの技術を利用して、予測理由が可視化できる文書分類モデルを実装しました。
self-attentive sentence embedding という論文の手法を単純化したものになります。
この手法は次のような手順で予測を行います。

  1. bidirectional LSTMで文書を変換
  2. 各単語に対応する隠れ層(下図$h_i$)を入力とし、予測の際その単語に注目すべき確率(self attention 下図$A_i$)をNeural Networkで予測
  3. self attention の重み付で各単語に対応する隠れ層を足し合わせたものを入力とし、Neural Networkで文書のラベルを予測

この$A_i$を可視化してやれば、モデルが予測の際どの単語に注目したかを知ることができます。
(オリジナル論文では複数個のself attentionを利用する方法が提案されているのですが、今回は簡易のためattentionは1種類としています。)

image.png

実装

上記手法をpytorchで実装してみました。
bidirectional LSTMの部分は次のような感じになります。

class EncoderRNN(nn.Module):
    def __init__(self, emb_dim, h_dim, v_size, gpu=True, v_vec=None, batch_first=True):
        super(EncoderRNN, self).__init__()
        self.gpu = gpu
        self.h_dim = h_dim
        self.embed = nn.Embedding(v_size, emb_dim)
        if v_vec is not None:
            self.embed.weight.data.copy_(v_vec)
        self.lstm = nn.LSTM(emb_dim, h_dim, batch_first=batch_first,
                            bidirectional=True)

    def init_hidden(self, b_size):
        h0 = Variable(torch.zeros(1*2, b_size, self.h_dim))
        c0 = Variable(torch.zeros(1*2, b_size, self.h_dim))
        if self.gpu:
            h0 = h0.cuda()
            c0 = c0.cuda()
        return (h0, c0)

    def forward(self, sentence, lengths=None):
        self.hidden = self.init_hidden(sentence.size(0))
        emb = self.embed(sentence)
        packed_emb = emb

        if lengths is not None:
            lengths = lengths.view(-1).tolist()
            packed_emb = nn.utils.rnn.pack_padded_sequence(emb, lengths)

        out, hidden = self.lstm(packed_emb, self.hidden)

        if lengths is not None:
            out = nn.utils.rnn.pad_packed_sequence(output)[0]

        out = out[:, :, :self.h_dim] + out[:, :, self.h_dim:]

        return out

attentionクラスです。
LSTMの隠れ層を入力として、各単語へのattentionを出力します。

class Attn(nn.Module):
    def __init__(self, h_dim):
        super(Attn, self).__init__()
        self.h_dim = h_dim
        self.main = nn.Sequential(
            nn.Linear(h_dim, 24),
            nn.ReLU(True),
            nn.Linear(24,1)
        )

    def forward(self, encoder_outputs):
        b_size = encoder_outputs.size(0)
        attn_ene = self.main(encoder_outputs.view(-1, self.h_dim)) # (b, s, h) -> (b * s, 1)
        return F.softmax(attn_ene.view(b_size, -1), dim=1).unsqueeze(2) # (b*s, 1) -> (b, s, 1)

最後にattentionを利用して実際に文書分類を行う部分です。

class AttnClassifier(nn.Module):
    def __init__(self, h_dim, c_num):
        super(AttnClassifier, self).__init__()
        self.attn = Attn(h_dim)
        self.main = nn.Linear(h_dim, c_num)


    def forward(self, encoder_outputs):
        attns = self.attn(encoder_outputs) #(b, s, 1)
        feats = (encoder_outputs * attns).sum(dim=1) # (b, s, h) -> (b, h)
        return F.log_softmax(self.main(feats)), attns

これらのNeural Networkを同時に学習させます。

実験

今回はIMDB映画レビューのネガポジ判別を行ってみます。
このデータは映画のレビューに対して、positiveかnegativeかをタグ付けしたデータセットで、torchtextなどから簡単に利用することができます。

  • 単語の分散表現の次元は100
  • LSTMの隠れ層の次元は32

と比較的小さなNetworkを利用しましたが、90%程の精度を達成できました。

予測理由(attention) の可視化

検証用データを用いてattentionを可視化してみました。
赤いハイライトの濃さがattentionの強さを表しています。
(Qiitaってspanタグ使えないんだ。。)

正解:POSITIVE 予測:POSITIVE なデータ

pospos.png
good, brilliant などの、単語の意味自体がpositiveなものや、highly recommendといった映画のレビューの文脈ではpositiveといえるものがハイライトされているのが観察できます。

正:NEGATIVE 予測:NEGATIVE なデータ

negneg.png

negativeなレビューの場合も、worstやhateといった単語の意味自体がnegativeなものが強くattentionされています。

正解:POSITIVE 予測:NEGATIVE なデータ

posneg.png

これはpositiveなレビューなのにnegativeと予測してしまった例です。attentionを観察するかぎりpassという単語に注目してnegativeと予測してしまったようです。
たしかに、"pass"という単語は「こんな映画みても無駄だからpassしろ」といった感じでnegativeに使われることも多いですが、今回は"pass it on"で「他の人にもこの映画を広めてほしい」というニュアンスで使われていると思います。たぶん。(英検3級並感)
このイディオムをとらえきれなかったようですね。

正解:NEGATIVE 予測:POSITIVE なデータ

negpos.png
これは逆にnegativeなレビューなのにpositiveと判断してしまった例です。「この映画の製作陣が二度と再結成しないことを祈るよ。」というかなり皮肉に富んだレビューです。はっきりとnegative / positiveを表している単語が少なく、この言い回しの真意を読み取れなかったようです。

まとめ

今回はself attentionを使用して、予測理由が簡単に可視化できる文書分類モデルを実装しました。どの単語に注目したか可視化するだけでも結構説得力のあるモデルになっていると思います。
予測を間違えたデータの分析も予測理由の可視化ができるとわかりやすいですね。

コード

https://github.com/nn116003/self-attention-classification

参考文献

[1]Zhouhan Lin, Minwei Feng, Cicero Nogueira dos Santos, Mo Yu,Bing Xiang, Bowen Zhou & Yoshua Bengio A STRUCTURED SELF-ATTENTIVE SENTENCE EMBEDDING, ICLR 2017

itok_msi
python/pytorch/caffe/lasagne/deep learning/機械学習 機械学習系の論文について、いろいろな追加実験をした結果を記事にしていきたいです。
ntt-data-msi
数理科学とコンピュータサイエンスの融合!!
http://www.msi.co.jp/
Why not register and get more from Qiita?
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
Comments
No comments
Sign up for free and join this conversation.
If you already have a Qiita account
Why do not you register as a user and use Qiita more conveniently?
You need to log in to use this function. Qiita can be used more conveniently after logging in.
You seem to be reading articles frequently this month. Qiita can be used more conveniently after logging in.
  1. We will deliver articles that match you
    By following users and tags, you can catch up information on technical fields that you are interested in as a whole
  2. you can read useful information later efficiently
    By "stocking" the articles you like, you can search right away
ユーザーは見つかりませんでした