LoginSignup
2
2

埋め込みベクトルから文章をデコードするには?

Posted at

先日、Kaggleで開催されたLLM Prompt Recorveryコンペでブロンズメダルを獲得することができました!
そこではLLMに関する知見が色々得られたのですが、ちょっと自分の理解が足りていない点が多々あり、それによってスコアが伸び悩みました。
この記事では、LLM(というかより広く自然言語処理モデル)で予測したembedding空間上の文章を自然言語に復元するときに、具体的にどのような処理が行われているのかを整理します。

簡単なコンペの概要

このコンペでは、

  • original text
  • rewritten text

が与えられていました。
ここでrewritten textは、original textをLLMのGemmaを使ってあるプロンプト(rewrite prompt)で書き換えたときの文章です。
このrewrite promptを推測するのがこのコンペの目的であり、自分の推測したプロンプトと正解のrewrite promptがどれくらい似ているのかを競いました。
評価指標はSharpend Cosine Similarity(以下、scs)を使い、推測したプロンプトと正解のrewrite promptの埋め込みベクトルの類似度を測りました。
ここで埋め込みベクトルにする際には、SentenceTransformerが使われました。

大まかなアプローチ

このコンペにおいてスコアを上げていくための大まかな方法は、「平均的なプロンプトを作る」ということでした。
LLMにrewrite promptを予測させておいて、そこに平均プロンプトを足し合わせるとscsが上昇する傾向にありました。

やりたかったけどできなかったこと

最初に思い描いていた実現方法

私としては、考えうるプロンプトや、データセットを公開してくださっていた方々のプロンプトを収集して、それらをすべて埋め込みベクトルにして、それらの平均をとったものを自然言語に直そうというアイデアを思いつきました。

つまり、
①収集したプロンプトをそれぞれencode

②encodeして埋め込みベクトルになったものたちの平均をとる

③平均をとったベクトルをdecodeして自然言語を得る

ということをやろうとしました。
実際、コンペ終了後に上位陣の解法を確認すると、だいたい上と同じようなアイデアを使っており、方向性としては間違っていなかったのだなと思いました。

問題点

①と②はできましたが、③を行うことができませんでした。
文章を埋め込みベクトルにすることができるのなら、埋め込みベクトルを文章にすることもできると思い込んでいたのですが、どうやらそれはできないようでした。
つまり、
「文章をencodeして埋め込みベクトル空間にembeddingすること」は可能だが、
「埋め込みベクトル空間の中のあるベクトルを取り出してそれを文章にすること」は不可能のようでした。

どうやるべきだったか?

③で平均プロンプトに対応する自然言語の文章を得るには工夫が必要でした。
一言でその方法をざっくり表現すれば、トークンを色々組み合わせていき、平均プロンプトに近似していくという方法です。
より具体的には、平均プロンプトの埋め込みベクトルに対して、トークンを色々組み合わせて作った文章の埋め込みベクトルのscsを計算して、そのscsが高いものを得るというものです。

具体的にどのようなコードになるかというと以下のようになります。

import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("sentence-transformers/sentence-t5-base")

# 収集したプロンプトをそれぞれencodeして埋め込みベクトル化を行い、その各成分の平均をとった「平均埋め込みベクトル」を作る。
embedded_prompts = [model.encode([prompt]) for prompt in all_sample_prompts]
mean_embedding = np.mean(embedded_prompts, axis=0)

# 改善していくテキスト
current_best_str = "Please improve this text using the writing style with maintaining the original meaning but altering the tone."

def generate_new_str(current_best_str: str, **kwargs) -> str:
  """ 「平均プロンプト」を獲得するため探索を行うための関数。 """

  # 埋め込みを行うモデルで使われている単語の一覧を取得する。
  index_to_token = {v: k for k, v in model.tokenizer.get_vocab().items()}

  new_str = ""
  """
  ここでcurrent_best_strから新しいテキストを生成する。
  具体的には単語を足したり除いたり入れ替えたり置き換えたりする。
  新しいテキストの生成方法はヒューリスティックな方法になる。(後述)
  """
  return new_str

while True:
    # イテレーションを行なって「平均プロンプト」に近似していく。
    new_str = generate_new_str(current_best_str)
    
    new_str_embeddeing = model.encode([new_str])
    current_best_str_embedding = model.encode([current_best_str])
    if cosine_similarity(new_str_embeddeing, mean_embedding) > cosine_similarity(current_best_str_embedding, mean_embedding) :
        # より「平均埋め込みベクトル」とのscsが大きいものを暫定的にベストな「平均プロンプト」にする。
        current_best_str = new_str

なお、上記の平均プロンプトを探索するコードについては、以下のKaggleのディスカッションを参考にしました。
https://www.kaggle.com/competitions/llm-prompt-recovery/discussion/494499

また、generate_new_str関数では、テキストをベクトル空間に埋め込むためのモデルで使うコーパス(単語集)で新しいテキストの生成を行います。
この新しいテキストの生成はヒューリスティックな方法で行います。
総当たりで最適な文章を探し出したら良いではないかと思われるかもしれないですが、それだと実行時間が非常に長くかかってしまい現実的ではありません。
例えば今回のコンペで使われたSentenceTransformerのトークナイザのトークン数はmodel.tokenizer.vocab_sizeで得られますが、32100もの大きな値になります。
次にどの単語が来るのかを総当たりで試そうとしたら、
32100 * 32100 * 32100 * ...通りもの天文学的な数のパターンになってしまいます。
そこで、ベストな解よりはベターな解を求める方針にして、現実的な方法でより良いテキストを探していくことになります。
ビーム法などいくつか方法があるのですが、具体的な方法は以下のページに綺麗にまとまっています。
https://zenn.dev/hellorusk/articles/1c0bef15057b1d

勘違いしていたことや混乱していたことの個人的なまとめ

  • model.encodeとmodel.tokenizer.encodeを混同してしまっていた
    • model.encodeは埋め込みを行うものである
    • model.tokenizer.encodeはトークン化を行うものである
    • modelはモデル、model.tokenizerはトークナイザなのだから、言われてみれば当たり前のことではあるが、慣れないことをしていたため気づけなかった...
  • 埋め込んだベクトルから自然言語に変換するデコードを簡単にできると思っていた
    • メソッドひとつでこのようなデコードができると思っていた
    • 原理をよく考えてみると、文章をトークンに分けた段階で離散的になっているため、連続的なベクトル空間からピックアップしたあるベクトルが、離散的なトークン列にピッタリ対応することはまずない

個人的なコンペの反省

上位層の解法をみてみると、自分でも試してみて、早い段階でうまくいかないと思って諦めてしまっていたものがいくつもありました。
そこをしっかりとやり切ることができたら余裕を持ってシルバーメダル圏内に届いていたと思います。
こうした「やり切れなさ」は余計なところに認知の負担がかかることによって起こっているような気がしています。
つまり、見慣れぬライブラリの使い方や、読みにくいコードの見た目などによって、より考えたいことに考える力を割けていないということです。
初歩的なことではありますが、Pythonの基本文法やライブラリの使用への習熟、それに加えて読みやすいコードの作成法などを徹底して身につける必要性を感じました。

2
2
2

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2