2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

コード特化型のModernBERTモデル(CodeModernBERT-Owl)を開発した話

Posted at

はじめに

はじめまして。私は情報工学を専攻している学部生です。
個人的にTransformerやGPTなどの深層学習を用いた自然言語処理について学習していて、ある程度成果が出てきたのでまとめておこうと思い、この記事を執筆しています。

TransformerやGPTを使った深層学習モデルの研究が進む中、特にコード理解に特化したモデルであるCodeModernBERT-Owlの開発に成功しました。しかし、その技術的背景や活用方法が理解されにくいため、本記事ではその詳細と工夫点を紹介し、同様の研究を目指す学生や技術者の助けとなることを目指しています。

またモデルをHugging Face上(LLMやAIにとってのGitHubのようなプラットフォーム)にアップロードしています。
CodeModernBERT-Owl - Hugging Face

Transformerアーキテクチャの基本構造については本題とは異なるため、詳細な説明は省略します。理解を深めるためには、別の資料やサイトで図解を参照することをおすすめします。

まずは類似している技術であるCodeBERTについて解説しますが,技術的な内容を多く含んでいるため少し読みづらいかもしれません.

CodeBERTとは

CodeBERTは、Microsoft社が開発したモデルであり、RoBERTaアーキテクチャをベースにしてコードと自然言語の関連性を学習するように設計されているモデルです.つまりコードや自然言語について理解し,要約(実際は埋め込みベクトル)を作成します.

また,大規模なコードと自然言語のペアデータを用いて事前学習されており、以下のようなタスクに利用できます。

  • コード検索 (Code Search): コードスニペットを自然言語のクエリから検索する
  • コード補完 (Code Completion): 与えられたコード片から次のトークンを予測する
  • コード要約 (Code Summarization): コードの意味を要約し、自然言語で説明する
  • コード分類 (Code Classification): コードを特定のカテゴリに分類する

勘違いしやすい点であるため補足しておくとCodeBERT単体でこれらができるわけではなく,Seq2Seqのような形や,新しいヘッド(分類用やコード補完用の出力をする部分)をつけたりして実装します.いずれにしてもCodeBERTの双方向アテンションを用いて作成された高品質なコード埋め込みを分析し,利用することで実現しています.

より具体的に言うとコード検索タスクでは、CodeBERTの出力に類似度計算層を追加(またはクエリとコードの埋め込みをそれぞれ別で取得し,その類似度を計算)し、クエリとの関連度を算出します。コード補完タスクでは、CodeBERTの出力にデコーダ層を追加し、次のトークンを生成します。

CodeBERTの技術的特徴

CodeBERTはRoBERTaと同様の事前学習手法を採用していますが、学習データとしてCodeSearchNetというデータセットを用いて以下のようなタスクを用いて訓練されています.

  1. マスク言語モデリング (MLM: Masked Language Modeling)

    • ソースコード内の一部のトークンをマスクし、それを予測するタスク
  2. 置き換えられたトークンを識別するタスク (Replaced Token Detection, RTD)

    • ある単語が意図的に置き換えられた場合、それを識別するタスク

トークナイザの役割と重要性

トークナイザは、自然言語やコードを深層学習モデルが処理できる形に変換する重要な役割を担います。具体的には、テキストを単語や記号などの単位(トークン)に分割し、それぞれのトークンにIDを割り当てることで、数値ベクトルに変換します。

コード特化のトークナイザの利点

通常のトークナイザは、自然言語の特性に合わせて設計されています。一方、コードは自然言語とは異なる特性(記号の多用、特定の構文など)を持つため、通常のトークナイザでは効率的に処理できません。

コード特化のトークナイザは、コードの特性に合わせて設計されており、以下のような利点があります。

  • コードの構造をより正確に捉える:
    • コード内の記号やキーワードを適切にトークンとして分割することで、コードの構造をより正確に捉えることができます(つまりコードに頻出するセミコロン(;)やダブルクォーテーション("")など,英語などの自然言語に特化したトークナイザには含まれていない可能性がありUnknownTokenに置き換えられてしまう恐れがあります)
  • 語彙サイズの効率化:
    • コードに特有の単語や記号を効率的にトークン化することで、語彙サイズを削減し、モデルの学習効率を向上させることができます
  • コードの意味をより適切に表現:
    • コードの構文や意味を考慮したトークン化により、コードの意味をより適切に数値ベクトルとして表現できます.(特にCodeBERTは空白や改行を一度別の文字に置き換え,デコードする際戻すため,モデルとしては空白や改行を認識でき,ユーザ側も構造を損なっていないので,わかりやすいコードを維持できます)

CodeBERTにおけるトークナイザ
CodeBERTでは、コードの特性を考慮したトークナイザを使用することで、コードの構造や意味をより正確に捉え、コード関連タスクの性能向上に貢献しています。

CodeBERTは、Python、Java、JavaScript、PHP、Rubyなど複数のプログラミング言語に対応しており、自然言語とコードの間の関係を学習することで、プログラムの理解やコード生成タスクに有効活用されています。

CodeBERTの限界と改善点

CodeBERTは、コードと自然言語の関連性を学習する強力なモデルですが、いくつかの限界や改善点を抱えています.

  • 複雑(長い)なコードの理解:

    • CodeBERTは、比較的短いコードスニペットの理解には優れている反面、特にシーケンス長(モデルの読み込める長さ)が512トークンしかなくそれ以上になってしまうと入力部分がコメントのみになってしまい、うまく学習できなかったり、コードの特徴をとらえたコード埋め込みを出力できない恐れがある
  • ファインチューニングの必要性:

    • CodeBERTは、様々なタスクに利用できますが、それぞれにファインチューニングして適応させる必要があります
  • ModernBERTアーキテクチャの発表:

    • 2024年12月頃にModernBERTというこれまでのBERTやRoBERTaの改良されたアーキテクチャが発表されました.(https://huggingface.co/blog/modernbert)
    • このアーキテクチャは従来のBERTシリーズでは苦手だった長いシーケンス長に対応しやFlashAttentionを採用することで計算効率を上げています. (最初期のモデルは3epochで10時間程度,十分収束していました)
    • つまりCodeBERTに採用されているアーキテクチャの弱点である 短いシーケンス長を改善できる可能性があるということです
    • またほかにもそれぞれの企業や組織が独自にBERTの改良版を開発していたりします(Salesforce社のCodeXEmbedなど)

CodeXEmbedSalesforce/SFR-Embedding-Code-400M_Rについて
こちらも私のモデルと同じくコード特化のBERTを改良したモデルで,特にコード検索において非常に高い性能を持っています。このモデルはSalesForce社が独自に作成したアーキテクチャを用いていて,一番小さいモデルで400Mあり,汎用的な学習からだんだんコード検索に特化した学習へ移行していくことで精度を上げているようです。
また,個人的に興味深いと思った点がBPEではなくWordPieceでトークナイザを用いている点です。

FlashAttentionやRoPEについて
具体的な詳細は他の記事をご覧ください,特にFlashAttentionやRoPEはそれぞれLlamaのような近年のLLMに採用されているアーキテクチャであり,それぞれ理解しておくと効率のいいモデル作成に役立つと思います.またModernBERTの特徴である,シーケンス長が長くなっても前後の文脈をとらえられるようにするLocalAttentionも興味深いアイデアです。

CodeModernBERT-Owl

CodeModernBERT-OwlはModernBERTアーキテクチャとCodeSearchNetデータセット,code_x_glue_ct_code_to_textデータセット,またMIT,Apache-2.0,BSD 2-Clause License,BSD 3-Clause Licenseのライセンスで配布されているGitHubで公開されているソースコードからcodesearchnetのように関数とそのdocstringを抽出したShuu12121/java-codesearch-dataset-openShuu12121/rust-codesearch-dataset-openのデータセットを用いて,一から事前学習を行いました.(厳密には私がこのモデルを完成させる前に開発していた,Codesearchnetのデータセットを用いたモデルであるShuu12121/CodeMorph-ModernBERT-BPE-1.0に対して複数回継続学習を行っています.ほかにも派生モデルが多数存在していますが現時点で最高性能のものであるCodeModernBERT-Owlについてご紹介させていただきます.)

データセットの選定理由
今回選んだ,または作成したデータセットはどれもコードとDocstring(コードを簡潔に説明したコメント)が存在しているものを採用しています。そのためcode_x_glue_ct_code_to_textデータセットの本来の目的とは異なりますが学習データ、評価データとして採用しています.

CodeModernBERT-Owlの特徴

特に以下の3点が大きなポイントです。

  • 長いシーケンスに対応(最大2048トークン)(Microsoftの512トークンモデルと比較して優位性あり)
  • コード検索、コード理解、コードクローン検出に最適化
  • マルチ言語対応Python、PHP、Java、JavaScript、Go、Ruby、Rust

CodeModernBERT-Owlの実験結果

CodeModernBERT-Owlの基本情報についてはぜひモデルカードをご覧ください。
また、以下に示す表はcode_x_glue_ct_code_to_textデータセットで以下の手順で実験した結果です.

  • データセットを準備しそこからクエリ(Docstring)1個とコードを100個ランダムで取得する
  • モデルの種類に応じて埋め込みを計算(この時,どのデータセットにもDocstringがコメントとしてコード内に存在しているので削除してから計算)
  • クエリとコードからコサイン類似度を取り順位を計算
  • それぞれの評価指標で計算

 また評価指標にはMRRを用いて判断しました。

MRRについて
MRRは、検索タスクにおいて、正解となるコードが上位にランク付けされているほど高い値となる指標です。1に近いほど高性能であることを示します。(計算方法は順位の逆数の平均です)
つまり 1位→1, 2位→1/2, 3位→1/3 みたいなイメージです。
実験ではプールサイズ100(ランダムな100個のうち、正解は何番目とされるか)で計算しています。

ファインチューニングについて
今回はモデルそのものの性能を評価するためにファインチューニングについては実施していません。
実際にはSentenceTransformerを用いてファインチューニングするとより性能が上がる可能性があります。

プールサイズ100におけるMRR(正解の順位の逆数の平均)での比較

image.png

言語 / Language CodeModernBERT-Owl CodeHawks-ModernBERT Salesforce CodeT5+ Microsoft CodeBERT GraphCodeBERT
Python 0.8793 0.8551 0.8266 0.5243 0.5493
Java 0.8880 0.7971 0.8867 0.3134 0.5879
JavaScript 0.8423 0.7634 0.7628 0.2694 0.5051
PHP 0.9129 0.8578 0.9027 0.2642 0.6225
Ruby 0.8038 0.7469 0.7568 0.3318 0.5876
Go 0.9386 0.9043 0.8117 0.3262 0.4243

実験の結果、CodeModernBERT-Owlはすべての言語において既存のモデルを上回る検索精度を達成しました。特にGoやPythonでは顕著な改善が見られました。

CodeModernBERT-Owlの工夫点

CodeModernBERT-Owlに至るまで数々の工夫をしてきました.(時系列は上から順番になっています。名前ややこしくて申し訳ありません。適当にかぶらないようにつけていました…)具体的には

  • CodeBERTを通常のBERTで再現実装(Pythonのみに対応) Shuu12121/CodeMorph-BERT
  • 限られた計算資源でのCodeModernBERTモデルの実現(ModernBERTでの実装に変更,もともとシーケンス長を2048トークンに調整,トークナイザを6言語に対応など)Shuu12121/CodeMorph-ModernBERT
  • モデルのパラメータを調整(公式の実装で用いられていたものは採用せず)
  • トークナイザをSentencepiece→BPEに変更Shuu12121/CodeMorph-ModernBERT-BPE-1.0
  • 単純なランダムMLMのみではなく継続学習の方法を変更Shuu12121/CodeHawks-ModernBERT
  • 既存のデータセットだけでなく独自にデータセットを作成し,本来未対応であったRustへの対応,Javaでの精度改善Shuu12121/CodeModernBERT-Owl

このように,継続的なモデル改善を行い高い精度を実現しました.
特にトークナイザをSentencepieceで限界を感じ,よりコードに近いトークナイズをするBPEに変更した後,極端に精度が落ちましたが(下の表参考),新たな学習方法を試してみたところ何とか改善し,元々得意だったPython以外の言語(特にGo)でも精度が出るようになりました.

モデル/言語 Python Java JavaScript php ruby go
CodeMorph-ModernBERT-BPE-1.0 0.411 0.3627 0.2184 0.3998 0.2436 0.2495

コード検索以外でも

CodeModernBERT-Owlの実装をしているとき主にコード検索のタスクをメインに評価してきましたが,コードクローンの発見のタスクも行いました.CodeX-GLUE CC Clone Detection Big Clone Bench
その結果は以下に示す表のとおりであり,かなり高い精度で判別することができていました。

指標
精度 (Accuracy) 0.9883
F1スコア (F1 Score) 0.9571

また,現状試せていませんがCodeBERTがコード要約やコメント生成などで優れた結果を出していたことから,CodeModernBERT-Owlでも同様に実現できるのではないかなと考えています.
特にCodeModernBERT-Owlから得られるコードとそれに対応する自然言語の埋め込み結びつきが強いことはコード検索において既に示されているため、小さいモデルでもより手軽にコメント推薦に役立てられる可能性があります.

最後に

 ここまで読んでいただいてありがとうございました.
 今後の課題として、理論と実践のギャップを埋める必要があると感じています。
また,コード要約やコメント生成といったタスクにおいても本モデルが有効かどうか検証する予定です。また、評価手法の妥当性についても再考する必要があります。(特に外部からの評価がない状態なので妥当性や再現性のある方法を模索しています。)
 個人的にこのモデルが有効に使われる可能性があるのは組み込みやIDEで駆動するような形を想像しています.具体的に検証したわけではないのですが,精度のみ突き詰めた場合はより大きいパラメータを持つモデルには勝てないだろうと思います.しかしこのモデルは小さく、精度がそれなりに出ている。また小さいからこそコードクローンやコードサーチなど様々な使い方やデータ,環境に適応させることがより簡単になると思っています。
最後に今後の展望として、モデルの軽量化と精度向上を両立する研究が求められています。CodeModernBERT-Owlがその一助となることを願いつつ、さらなる改良を重ねていきたいと考えています。

デモ

まずTransfomrsを4.48.0以上にバージョンを上げてください

pip install -U transformers>=4.48.0

次にライブラリをインポートします

from transformers import AutoTokenizer, ModernBertForMaskedLM
import torch

さらにモデルをダウンロードします

# モデル名
repo_name = "Shuu12121/CodeModernBERT-Owl"
# Hugging Face からモデルをロード
model = ModernBertForMaskedLM.from_pretrained(repo_name)
tokenizer = AutoTokenizer.from_pretrained(repo_name)
print("モデルのロード成功!")
print(model)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"使用デバイス: {device}")
model.to(device)

次に埋め込みを取得します

import torch

def get_embedding(text, model, tokenizer, device="cuda"):
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
    # token_type_ids があれば削除する
    if "token_type_ids" in inputs:
        inputs.pop("token_type_ids")
    inputs = {k: v.to(device) for k, v in inputs.items()}
    outputs = model.model(**inputs)
    embedding = outputs.last_hidden_state[:, 0, :]
    return embedding

code = "def my_function(): pass"
embedding = get_embedding(code , model, tokenizer)
print(embedding.shape)

取得する埋め込みを変更したい場合

code = "def my_function(): pass"

の部分を適宜変更してください

参考文献

2
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?