大規模言語モデル入門 の輪読会を開催したので、発表に使った資料を一部修正して公開します。
自然言語推論
- JNLI
- 自然言語推論のデータセット
- 2つの文章の関係性を3パターンに分類する。(分類)
- 2つの文章を1つの文章として繋げて、エンコーダにぶち込む。
- entailment=含意, contradiction=矛盾, neuetral=中立
意味的類似度計算
- JSTS
- 意味的類似度計算のデータセット
- 2つの文章の類似度を計算する。(回帰)
- 2つの文章を1つの文章として繋げて、エンコーダにぶちこんでいる。
- 回帰タスクなので相関係数や平均二乗誤差などで評価している。
- 解答スコア:JSTSのスコアは0-5をとるように作られている。
多肢選択式質問応答
- JCommonsenseQA
- 多肢選択式質問応答のデータセット。
- 5択のクイズ問題形式になっている。
- 質問と回答を繋げた5つの文章をそれぞれエンコーダに入れてスコアを出力し、最終的に1番高い選択肢を回答とする。
メモリ効率の良いファインチューニング
ハードウェアの制限状況下でファインチューニングを行うためのテクニックが4つある。Kaggleとかで使えそうな話。
- 自動混合精度演算
- FP32とFP16を使い分けて高精度かつ高効率な学習をする手法。
- ネットワークの前向き計算と誤差逆伝播はFP16を使う。
- パラメータ更新にはFP32を使う。
- 値が小さくなりすぎて0になるのを防ぐために、損失スケーリング(=値を定数倍する)なども行っている。
- 勾配累積
- 小さなバッチサイズで計算した勾配を集約することで、メモリ使用量を抑えながら実質のバッチサイズを増やす手法
- 例えばバッチサイズ16で2回勾配累積を行うと、バッチサイズ32で学習した場合と同じぐらいの精度が得られる。
- 勾配チェックポインティング
- 前向き計算の途中の計算結果を間引く方法
- torch.utils.checkpoints.checkpoints
- LoRAチューニング
- LoRA(Low-Rank Adaptation)
- 普通のファインチューニングはh=WXのWを更新するが、LoRAはh=(W+ΔW)Xとすることで、ΔWをチューニングしている。
- 差分の行列とはなんぞや...?
日本語大規模言語モデルの比較
-
LUKE (Language Understanding with Knowledge-based Embeddings)
- Wikipediaから得られるエンティティ情報を取り入れた日本語モデル
- エンティティとは
- 特定の概念(人物、場所、組織、出来事)を指す。
- エンティティを取り入れることによって、固有表現認識や質問応答などのタスクの性能を向上させる。
-
DeBERT V2 (Decoding-enhanced BERT with disentangled attention)
- Microsoftが開発したBERTのアーキテクチャを改良した日本語モデル
- Enhanced Mask DecoderとnGiEという仕組みが搭載されている。
性能比較
- JCommonsenseQAはちょっと難しめのタスクであることがわかる。
- DeBERTはアーキテクチャと学習データがOSCAR分多いから性能が高いと考えられる。