こんにちは。今回、LLMを使用したアプリケーションを作成することになり、LLMによるAPIコールがどのような仕組みで行われているのかについて学習したくなってきたので、有名なToolformerの論文を読んでまとめたいと思います。
概要
出力例
文章生成の適切な位置でAPIコールを意味する文字列を生成してAPIを使用している例
著者、論文情報
- 論文名: Toolformer: Language Models Can Teach Themselves to Use Tools
- 著者: Timo Schick, Jane Dwivedi-Yu, Roberto Dessì, Roberta Raileanu, Maria Lomeli, Luke Zettlemoyer, Nicola Cancedda, Thomas Scialom
- 所属: Meta AI Research
- タイトル: Toolformer: Language Models Can Teach Themselves to Use Tools
先行研究の問題点
- 現行のLarge Language Models (LLMs)はゼロショット・Few-Shotタスクで顕著な能力を示す一方、以下の基本的機能で課題を抱えている:
- アップデートされた情報へのアクセス不足
- Hallucination(虚偽の事実生成)の問題
- 計算スキルの欠如
- 時間的知識(temporal awareness)の不足
- Low-resource言語への対応力の弱さ
- ツール利用の既存手法では以下の限界がある:
- 大量の人間によるアノテーションを必要とする
- ツール使用が特定タスクに限定される
本研究の新規性
- 自己教師あり学習(self-supervised learning)を用い、言語モデルが外部ツールを独立的に利用する能力を習得。ツールには以下が含まれる:
- 計算機
- Q&Aシステム
- 検索エンジン(Wikipedia Search)
- 翻訳システム(Machine Translation System)
- カレンダー
- Toolformerは以下を実現:
- アノテーションなしでAPIコールを学習
- モデルがツール使用のタイミング、適切なツール選択、引数の決定を独自に判断
- 言語モデリング能力を損なわず、多用途での適用が可能
有効性評価の方法と結果
-
方法:
- ゼロショットで以下のタスク群を評価:
- 質問応答(Question Answering)
- 計算タスク
- 多言語タスク(Multilingual QA)
- 時系列タスク(Temporal Reasoning)
- 言語モデリング能力をパープレキシティ(Perplexity)で評価。
- ゼロショットで以下のタスク群を評価:
-
結果:
- Toolformerは下流タスクでGPT-JやGPT-3(175B)を超える性能。
- 特定タスクでのAPI使用率が高く、例としてMathタスクで97.9%が計算機APIを使用。
- 言語モデリングのコア性能を損なわず、Perplexityが維持されている。
今後の課題
-
ツールの連携利用(Chained Tool Usage):
- 一つのツールの出力を別のツールの入力として使用する機能が未対応。
-
対話型ツール利用(Interactive Tool Use):
- 検索結果のリファインや複数結果の操作が不可能。
-
プロンプト依存性(Prompt Sensitivity):
- モデルがツール利用を判断する際、入力の表現に敏感。
-
データ効率性の問題(Sample Efficiency):
- 有効なAPIコールを得るためのデータサンプリングが非効率的。
-
計算コスト最適化:
- APIコールのコストを考慮した最適化が未実装。
Introduction
-
Large Language Models (LLMs) はゼロショットやFew-Shotタスクで高い性能を発揮するが、以下のような基本的な機能で課題を抱える:
- 最新情報へのアクセスが困難
- 事実のhallucination(誤情報生成)のリスク
- 計算能力の不足
- 時間的な文脈(temporal awareness)の欠如
- Low-resource言語の理解が不十分
-
現在のLLMsの制限に対処する一つの方法は、外部ツール(検索エンジン、計算機、カレンダーなど)を利用可能にすること。
-
既存のツール利用手法の課題:
- 大量の人間アノテーションが必要
- ツール利用が特定タスクに限定され、汎用性が低い
-
本研究では以下を提案:
-
Toolformer: 外部ツールを自己教師あり学習で利用可能にするモデル
- モデル自身がどのツールを、いつ、どのように利用すべきかを学習
- APIコールの利用例を人間による大量のアノテーションなしに取得
-
Toolformer: 外部ツールを自己教師あり学習で利用可能にするモデル
-
Toolformerの目標:
- モデルが汎用性を損なわず、広範なツール利用を可能にする
- 自己の予測精度向上に役立つツール利用を自律的に判断可能にする
Approach
本論文では、人間による大量のアノテーションなしにLLMにAPIの呼び出し方を教える。例として「The Eiffel Tower is located in」という文字列を挙げる。
数式なしの解説
- 入力:The Eiffel Tower is located in
- Ground Truth:The Eiffel Tower is located in Paris
- データセットC:Toolformerを学習するためのデータセットC*の元となるデータセット
以下の流れToolformerの学習を行う。
- 事前学習済みLLM($M$とする)を用意する
- $M$にAPIの生成方法の例をプロンプトとして与える(例:Wikipediaの呼び出し方や計算機の呼び出し方)
- $M$に「The Eiffel Tower is located in」の次のトークンの確率を推定
- 「The Eiffel Tower is located in <API>」の確率が閾値以上ならAPIを呼び出し、埋め込み「例:"The Eiffel Tower is located in" + {Wikipedia APIの出力}」
- "The Eiffel Tower is located in" + {Wikipedia APIの出力} を入力として次のトークンを出力(正解はParis)
- Parisの出現確率がある程度上がったら(= 損失が閾値以上下がったら)APIコールを含めた文字列を新たなデータセットC*とする
- C* を使用してToolformerを学習
数式を用いた解説
ステップ1: サンプリング (Sampling API Calls)
-
モデル $M$ に対し、プロンプト $P(x)$ を与え、APIコールを挿入する位置をサンプリング
-
APIコールを開始する確率:
$$
p_M(\text{<API>} | P(x), x_{1:i-1})
$$ -
サンプリング閾値 $\tau_s$ を用いて、APIコールの候補位置を決定:
$$
I = \lbrace i ,|, p_M(\text{<API>} | P(x), x_{1:i-1}) > \tau_s \rbrace
$$
-
-
各位置 $i \in I$ に対し、APIコール候補を生成:
- $c_i^1, c_i^2, \ldots, c_i^m$: APIコールの候補リスト
ステップ2: 実行 (Executing API Calls)
-
生成したAPIコールを実行し、結果 $r_i$ を取得
- 各APIは以下の形式を持つ:
$$
c_i = (a_c, i_c) \quad \text{(API の名前 $a_c$ と入力 $i_c$)}
$$ - 応答を含むシーケンス:
$$
e(c_i, r_i) = \text{<API>} a_c(i_c) \to r_i \text{</API>}
$$
- 各APIは以下の形式を持つ:
-
$a_c(i_c)$の例として例えば"Wikipedia (Eiffel Tower)"などが挙げられる
-
$\to$の先にAPIの返却値を埋め込む ($=r_i$)
ステップ3: フィルタリング (Filtering API Calls)
-
各APIコール $c_i$ の応答が予測精度向上に寄与するかを評価:
- 予測損失:
$$
L_i(z) = -\sum_{j=i}^n w_{j-i} \cdot \log p_M(x_j | z, x_{1:j-1})
$$ - $z$ はAPIによる付加情報を示す。ないかもしれないし、APIを読んだが帰ってきていない文字列かもしれないし、APIを呼んで値も含んでいる文字列かもしれない
- $x_{1:j-1}$は、これまでの文章を表す
- $w_{j-i}$: トークン間の重要度を表す重み。APIコールが予測トークンに近いほど、つまり、APIコールの直後に良い回答が生まれるほど、そのAPIコールが重要とみなされる
- 予測損失:
-
APIコールの有効性を次の損失差で評価:
応答付きの損失が応答なしの損失と比較してある程度以上改善されているかをチェック。改善されていればあらたなデータセットにAPIの呼び出し部分が加えられる。- 応答付きの損失:
$$
L_i^+ = L_i(e(c_i, r_i))
$$
$c_i$はAPI呼び出しのための文字列、$r_i$はAPIの戻り値、$e(c_i, r_i)$はAPIを呼び出した後の呼び出しと結果を整形した文字列。これらを踏まえて次単語を予測し、クロスエントロピーで損失を取る。 - 応答なしの損失:
$$
L_i^- = \min(L_i(\epsilon), L_i(e(c_i, \epsilon)))
$$
$\epsilon$は空文字列を表す。つまり、APIを呼んだ際の比較対象は、APIそもそも呼ばなかった場合か、APIを読んだが結果が帰ってこなかった方の良い方ということになる。 - フィルタリング条件:
$$
L_i^- - L_i^+ \geq \tau_f
$$
APIを呼んだ場合とそうでない、もしくは、読んだがダメだった場合と比較して、損失が一定以上下がっていれば、APIを呼ぶ処理をデータセットに含める
- 応答付きの損失:
ステップ4: ファインチューニング (Model Fine-Tuning)
-
フィルタリング後のデータセット $C^*$ を生成:
- 元のデータ $C$ に有効なAPIコールを挿入:
$$
x^* = x_{1:i-1}, e(c_i, r_i), x_{i:n}
$$
- 元のデータ $C$ に有効なAPIコールを挿入:
-
$C^*$ を用いてモデル $M$ をファインチューニング:
- 言語モデリングの目的関数を適用
ステップ5: 推論 (Inference)
通常のデコーディングを実行し、APIコールが必要と判断された場合:
- 特定のトークン $\to$ が生成された際、デコーディングを一時停止
- 適切なAPIを呼び出し、応答を取得して再度デコーディングを続行