本記事は,京都大学人工知能研究会KaiRAのAdvent Calender 1日目の記事です.
作ったもの
京都大学人工知能研究会KaiRAの11月祭で出すAIデモの1つとして、下記のようなことを行うAIを開発しました。
- 歌詞を入力すると、その歌詞がどのアーティストっぽいかを判定する
- 予測結果だけでなく、どの部分に着目して判定したかをXAI手法で可視化
Google Colaboratoryで動きますのでぜひ試してみてください。
StreamlitでWebアプリを作りました!(ただし予測根拠の表示は無し)
動作イメージ
試しに、あいみょんの「マリーゴールド」のサビ部分を入力すると、以下のような結果が表示されます。
あいみょんの曲なので当然「あいみょん:90%」と高い確率になるのですが、その次に 「スピッツ:80%」 となっています。
あいみょんはスピッツに影響を受けており、あいみょんの代表曲の1つ「君はロックを聴かない」はスピッツの「醒めない」に感銘を受けたことをきっかけに作曲されたとインタビューで語っています1。
歌詞がどのアーティストから影響を受けたのかも伺える、そんな結果となりました。
さらに、予測結果だけでなくその根拠も可視化できるようになっています。
オレンジ色でハイライトされている単語は「そのアーティストっぽい単語」と認識されている部分で、青色はその逆です。
しかしながら……
「判定理由を見てもよくわからん」
というのが正直なところです。
これは、採用したXAI手法に問題があると考えられます。
予測根拠の説明にはLIME(Local Interpretable Model-agnostic Explanation) という手法を用いていて、簡単にいうと「元モデルを線形回帰モデル(などの解釈可能モデル)で近似することで、説明を得る」という手法です。
※以前私が書いた記事で、LIMEについて紹介したものがあったので、興味のある方はご覧ください。
ここでは元モデル(中身はニューラルネットワーク)を、各単語の有無を特徴量とする線形回帰モデルで近似し、線形モデルの回帰係数の重みをハイライトの色に反映しています。
ということは「判定理由」に現れる結果は「単語の有無」のみを考慮した説明になっており
- 単語同士のつながり
- 単語の順番
などは考慮されていないわけです(同じ単語には同じハイライトがついていることからもわかると思います)。
直感的にも、
「あ~この曲、あいみょんっぽい」
と感じるのは「この単語が含まれているから」ではなくて、「全体的にあいみょんっぽい感じがする」とか「単語の並び方があいみょんっぽい」とかそういう理由からではないでしょうか。
なので、LIMEによる説明はあまり向いていなかったのかもしれません…。
中身の話(モデル、データ、精度)
ここからは、モデルの中身や学習データ、精度について簡単に説明していきます。
モデル
自然言語処理タスクなので、RNN系のネットワークを使うことが多いと思いますが、今回はCNNを使いました。character-level-CNN2 呼ばれるアーキテクチャを採用しています。
日本語は、英語のように単語がスペースで区切られていないので分かち書きが面倒です。そこで「単語ごとではなく文字レベルで分けてしまおう!」というのがcharacter-level-CNNのアイディアです。
下の図は論文2のFigure1からの引用で、モデルの概要を表しています。
ざっくり図を説明すると次のような感じです。
- テキストをOne-hotベクトルに変換
- 畳み込み演算をLength方向についてのみ行う
- Max-poolingを間に挟んだりして層を重ねる
- 全結合層に通して出力
画像の場合、畳み込み領域は縦方向にも横方向にもずらしながら演算を行っていきますが、今回の1次元畳み込み ではLength方向(=文字が並ぶ方向)のみにずらしながら演算を行っていきます。
ちなみに、Length方向の長さは固定長でなければならないので、最大長$l_0$をあらかじめ設定しておき、それを超える分は切り捨て、最大長に満たないぶんはゼロパディングを行います。
PyTorchによる実装は以下のとおりです。1次元畳み込みはnn.Conv1d
が使えます。
また、畳み込み層に通す前にEmbeddingを組み込んだり、BatchNorm・Dropoutなどを加えています。
こちら↓の記事を参考にさせていただきました。
class CharacterCNN(nn.Module):
def __init__(self, num_classes ,embed_size=128, max_length=200, filter_sizes=(2, 3, 4, 5), filter_num=64):
super().__init__()
self.params = {'num_classes': num_classes ,'embed_size':embed_size, 'max_length':max_length, 'filter_sizes':filter_sizes, 'filter_num':filter_num}
self.embed_size = embed_size
self.max_length = max_length
self.filter_sizes = filter_sizes
self.filter_num = filter_num
self.embedding = nn.Embedding(0xffff, embed_size)
self.conv_layers = nn.ModuleList([
nn.Conv1d(embed_size, filter_num, filter_size) for filter_size in filter_sizes
])
self.fc1 = nn.Linear(filter_num * len(filter_sizes), 64)
self.batch_norm = nn.BatchNorm1d(64)
self.dropout = nn.Dropout(0.5)
self.fc2 = nn.Linear(64, num_classes)
def forward(self, x):
embedded = self.embedding(x).transpose(1,2)
conv_outputs = []
for conv_layer in self.conv_layers:
conv_output = F.relu(conv_layer(embedded))
pooled = F.max_pool1d(conv_output, conv_output.size(2)).squeeze(2)
conv_outputs.append(pooled)
convs_merged = torch.cat(conv_outputs, dim=1)
fc1_output = F.relu(self.fc1(convs_merged))
bn_output = self.batch_norm(fc1_output)
do_output = self.dropout(bn_output)
fc2_output = self.fc2(do_output)
return fc2_output
データ
入力単位は「歌詞ブロック」としています。
例えば「マリーゴールド」の歌詞は
風の強さがちょっと
心を揺さぶりすぎて
真面目に見つめた
君が恋しいでんぐり返しの日々
可哀想なふりをして
だらけてみたけど
希望の光は目の前でずっと輝いている
幸せだ
...
ですが、これを空行で区切った
風の強さがちょっと
心を揺さぶりすぎて
真面目に見つめた
君が恋しい
が「歌詞1ブロック」です。
そしてこれを文字ごとに区切ります。
["風","の","強","さ","が","ち","ょ","っ","と","\n","心","を","揺","さ","ぶ","り","す","ぎ","て","\n","真","面","目","に","見","つ","め","た","\n","君","が","恋","し","い"]
改行の度合いにもアーティストっぽさが出るかなと思って、改行文字も含めるようにしています。
そして、各文字をord()
関数でUnicodeのコードに変換し、nn.Embedding()
に渡しています。
歌詞全体を入力したいときは、各ブロックをそれぞれモデルに入力し、出力の平均を予測結果とすればよいですね。
学習方法
学習は次のように行いました。
- 41クラス分類モデルを学習させる
- 収集した全41アーティストのうちどのアーティストの歌詞かを分類させる
- Embeddingを学習させる目的
- Embeddingの重みを固定・全結合層を取り換えて、各アーティストに対して2クラス分類モデルを学習させる
- 「あいみょんの歌詞か、そうでないか」「スピッツの歌詞か、そうでないか」…というモデルを41個学習させる
- Embedding以外の畳み込み層・全結合層を学習させる
train-validの分割はGroupKFold($K=5$)で行いました。今回の場合、Group=曲です。
同じ曲の中には似たような歌詞が出てきやすいため、trainに含まれる曲の歌詞がvalidationデータに混ざっていると、不当に高い精度が出てしまうおそれがあるからです。
精度
各アーティストモデル(2値分類)のF1スコアのFold平均は下図のとおりです。0.8前後でした。
ソースコード
ソースコードはGitHubに上げてあります。