12
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

NRI OpenStandiaAdvent Calendar 2024

Day 19

AWS Trainiumを利用してLlama 3.1 8Bの継続事前学習をしてみた

Last updated at Posted at 2024-12-18

はじめに

AWS Trainiumを使ったLlama 3.1の継続事前学習に携わる機会をいただいたので、
Trainiumの概要や継続事前学習の手順について解説したいと思います。
また、学習を行う際に躓いた点があったため、都度解説します。

AWS Trainiumとは

AWS Trainiumとは、AWSが提供する深層学習を行うための機械学習アクセラレータです。

Llamaなどの大規模言語モデルの学習には、GPUなどの並列計算に強い機械学習アクセラレータが必要不可欠です。

Trainiumは大規模言語モデルのトレーニングに最適化されており、AWS上でTrainium搭載のTrn1インスタンスとして提供されています。

従来のGPUインスタンスと比べたメリット

コスト面でメリットがあると考えられます。

AWSの公式サイトでは、「他の同等の Amazon EC2 インスタンスと比べて、トレーニングコストを最大 50% 削減」とあります。
この「同等の Amazon EC2 インスタンス」が何かははっきり言及されていませんが、おそらくP4de.24xlargeインスタンスを指しているものと思われます。このインスタンスとの比較をしてみます。
(価格はOregonリージョン。2024年12月10日現在)

Instance Size GPU / Trainium アクセラレータ メモリ vCPU メモリ アクセラレータ P2P BW オンデマンド価格* (USD/時間)
P4de.24xlarge 8 640 GB 96 1152 GiB 600 GB/s 40.96
Trn1.32xlarge 16 512 GB 128 512 GB 768 Gbps 21.50
Trn1n.32xlarge 16 512 GB 128 512 GB 768 Gbps 24.78

単純な価格比較では50%削減に至りません。しかし、TrainiumにはサポートするAWS Neuron SDKなど、AWS側がTrainiumに機械学習の高速化に対する工夫を多数施しているため、学習時間を考慮すれば「最大 50% 削減」できるとAWSは謳っているようです。

P5インスタンス、Trn2インスタンスの登場

記事を書いている最中に、AWSからP5インスタンス・P5enインスタンスTrn2インスタンスが登場しました。

P5およびP5enインスタンスは、それぞれNVIDIAのH100、H200を搭載したインスタンスです。前世代(おそらくP4)のインスタンスと比較してトレーニングコストを40%削減できるそうです。

Trn2インスタンスは、EC2 P5e および P5en インスタンスよりも 30 〜 40% 優れたコストパフォーマンスを提供しているそうです。

詳しく書くと記事が長くなりすぎるため、今回は割愛します。

全体アーキテクチャ図

全体アーキテクチャ図を以下に示します。

image.png

主要なリソースの役割は以下の通りです。

Head Node

  • クラスター全体の制御・管理を行い、Worker nodeへのタスク分配やスケジューリングを担当します。
  • 学習の進捗管理やWorker node間の同期、モデルの統合なども行います。

Worker Nodes

  • GPUを使用して実際の学習処理や計算処理を実行します。
  • バッチ処理やデータの前処理、モデルパラメータの更新、モデル形式の変換を担当します。

FSx

  • 学習データやモデルチェックポイントを保存する高速なストレージシステムです。
  • 複数のノード間でデータを共有し、並列読み込みや低レイテンシでのアクセスを可能にします。

学習実行までの手順

ここからの手順は

警告
以降の操作には数百ドルのコストがかかります。
実行する際はリソースの消し忘れにご注意ください(特にFSx周り)。

Trn1インスタンスのサービスクォータの引き上げ

Trn1インスタンスにはAWSリージョンごとに利用可能なリソースの最大値(サービスクォータ)が定められています。
デフォルトでは0(=利用できない)となっているため、この値を引き上げます。
ここではクォータの単位がvCPUとなっています。
Trn1n.32xlarge インスタンスのvCPU数は128であり、以降の手順では4台のインスタンスを利用するため、128×4=512 以上の値に設定します(画像では1024に設定されています)。

image.png

VPCの準備

デフォルトのVPCを利用してもよいですが、リソースの簡潔な管理のために新しくVPCを作成します。
このページの手順を参考に作成しました。
AZはTrn1が利用できるゾーンを指定する必要があります。私はus-west-2dを指定しました。

また、VPC作成後、VPC内のパブリックサブネットにパブリックIPv4アドレス自動割り当て設定を適用します。
この設定により、後ほど実行するpclusterコマンドで、Head NodeにパブリックIPアドレスが付与されSSHログインが可能になります。

Parallel Clusterの構築

AWS CloudShell上で以下のコマンドを入力し、AWS ParallelClusterをインストールします。
AWS公式のParallelClusterインストール手順を参考にしました。

python3 -m pip install --upgrade pip
python3 -m pip install --user --upgrade virtualenv
python3 -m virtualenv ~/apc-ve
python3  -m virtualenv -p $(which python3) ~/apc-ve
source ~/apc-ve/bin/activate
python3 -m pip install --upgrade "aws-parallelcluster"
curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.38.0/install.sh | bash
chmod ug+x ~/.nvm/nvm.sh
source ~/.nvm/nvm.sh
nvm install --lts
node --version

インストールされたか確認します。バージョンが出力されたらOKです。私の場合は{"version": "3.11.1"}でした。

pcluster version

Head Node、Worker Node、FSxを作成するためのlaunch.yamlファイルを作成します。
以下のタグを、自身が作成したリソースに基づいて置き換えます。

  • <PUBLIC SUBNET ID>: VPC作成時に作成したパブリックサブネットのID
  • <PRIVATE SUBNET ID>: VPC作成時に作成したプライベートサブネットのID
  • <KEY NAME WITHOUT .PEM>: SSHに用いるpemキー名(拡張子.pemは含まない
tee launch.yaml > /dev/null << EOF
Region: us-west-2 # 対象リージョン 
Image:
  Os: ubuntu2004
HeadNode:
  InstanceType: r6a.12xlarge #ヘッドノードとして使用する EC2 インスタンスタイプ
  Networking:
    SubnetId: subnet-<PUBLIC SUBNET ID>
  Ssh:
    KeyName: <KEY NAME WITHOUT .PEM> 
  LocalStorage:
    RootVolume:
      Size: 1024
  CustomActions:
    OnNodeConfigured:
      Script: s3://neuron-s3/pcluster/post-install-scripts/neuron-installation/v2.20.1/u20/pt/install_neuron.sh
  Iam:
    S3Access:
       - BucketName: neuron-s3
         EnableWriteAccess: false
Scheduling:
  Scheduler: slurm
  SlurmQueues:
    - Name: compute1
      CapacityType: ONDEMAND
      ComputeSettings:
        LocalStorage:
          RootVolume:
            Size: 1024
          EphemeralVolume:
            MountDir: /local_storage
      ComputeResources:
        - Efa:
            Enabled: true
          InstanceType: trn1.32xlarge
          MaxCount: 4
          MinCount: 4
          Name: queue1-i1
      Networking:
        SubnetIds:
          - subnet-<PRIVATE SUBNET ID>
        PlacementGroup:
          Enabled: false
      CustomActions:
        OnNodeConfigured:
          Script: s3://neuron-s3/pcluster/post-install-scripts/neuron-installation/v2.20.1/u20/pt/install_neuron.sh
      Iam:
        S3Access:
          - BucketName: neuron-s3
            EnableWriteAccess: false
SharedStorage:
- FsxLustreSettings:
    DeploymentType: SCRATCH_2
    StorageCapacity: 1200
  MountDir: /fsx
  Name: pclusterfsx
  StorageType: FsxLustre
EOF

クラスターのデプロイを行います。

pcluster create-cluster --cluster-configuration launch.yaml -n cluster-test

出力されているjsonに"cloudformationStackStatus": "CREATE_IN_PROGRESS"があれば、作成自体は開始しています。内部的にはCloudFormationを使ってリソースを作成しているため、作成の進捗はマネジメントコンソールから追うことができます。かなり時間がかかりますが(自分は25分程度かかりました)、気長に待ちましょう。

pclusterコマンドの利用法について補足します。
クラスターのステータスは以下のコマンドでも確認することができます。

pcluster describe-cluster --cluster-name cluster-test

また、Worker Nodeを停止させるには以下のコマンドを実行します。1インスタンスあたり$21.5/hと高額なインスタンスであるため、こまめに停止することをおすすめします。

pcluster update-compute-fleet --cluster-name cluster-test --status STOP_REQUESTED

Worker Nodeを起動させるには以下のコマンドを実行します。

pcluster update-compute-fleet --cluster-name cluster-test --status START_REQUESTED

必要なファイル・データセットの準備

SSHでHead Nodeに接続します。pemキーの実行権限がowner以外に付与されているとエラーが出て接続できないため、権限を変更しておく必要があります。

chmod 700 <KEY NAME>.pem
ssh -i <PATH_TO_YOUR_PEM_KEY> ubuntu@<PUBLIC_IP>

Neuron SDKがダウンロードされた環境をactivateします。

source ~/aws_neuron_venv_pytorch/bin/activate

作業用のディレクトリを作成し、その配下に必要なスクリプト群をダウンロードしておきます。

mkdir -p ~/examples/tp_zero1_llama_hf_pretrain
cd ~/examples/tp_zero1_llama_hf_pretrain

wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama/requirements.txt
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama_hf_pretrain.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama/tp_zero1_llama_hf_pretrain/logger.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama/training_utils.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama/modeling_llama_nxd.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama/get_dataset.py
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama/tp_zero1_llama_hf_pretrain/tp_zero1_llama3_8B_hf_pretrain.sh

requirements.txtを利用して、依存するパッケージのインストールを行います。

python3 -m pip install -r requirements.txt

Llama3.1 8BのConfigファイルをダウンロードします。

mkdir 8B_config_llama3 && cd 8B_config_llama3
wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama/tp_zero1_llama_hf_pretrain/8B_config_llama3.1/config.json
cp config.json ../
cd ..

tokenizerをダウンロードするスクリプトを作成し、実行します。
'your_own_hugging_face_token'は自身が取得したトークンに置き換えてください。

tee ./get_tokenizer.py > /dev/null << EOF
from huggingface_hub import login
from transformers import AutoTokenizer

login(token='your_own_hugging_face_token')
tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-3.1-8B')
tokenizer.save_pretrained(".")
EOF

python3 get_tokenizer.py

学習データをダウンロードし、tokenizeを実行します。実行後、データセットをFSxに移動します。

cd ~/examples/tp_zero1_llama_hf_pretrain
python3 get_dataset.py --llama-version 3
mv ~/examples_datasets /fsx/

モデルのダウンロード・変換

初期ウェイトをダウンロードするスクリプトを作成・実行します。
'your_own_hugging_face_token'は自身が取得したトークンに置き換えてください。

cd ~/examples/tp_zero1_llama_hf_pretrain
tee ./save_model_weights.py > /dev/null << EOF
from transformers import AutoConfig, AutoModelForCausalLM
from huggingface_hub import login
import torch
login(token='your_own_hugging_face_token')
config = AutoConfig.from_pretrained('./config.json')
model = AutoModelForCausalLM.from_pretrained(
  "meta-llama/Llama-3.1-8B",
  config=config
)
torch.save(model.state_dict(), 'llama-31-8b.pt')
EOF
python3 save_model_weights.py

HuggingFace形式からNxD形式に変換するためのスクリプトをダウンロードします。

wget https://raw.githubusercontent.com/aws-neuron/neuronx-distributed/master/examples/training/llama/convert_checkpoints.py

変換後のウェイトの保存先ディレクトリを作成します。

mkdir -p checkpoint/pretrained_weight

変換を行います。Head NodeからWorker NodeのいずれかにSSHアクセスして、環境をactivateします。

ssh compute1-st-queue1-i1-1
source ./aws_neuron_venv_pytorch/bin/activate
cd ~/examples/tp_zero1_llama_hf_pretrain/

変換スクリプトを実行します。実行に5分程度かかります。

python3 convert_checkpoints.py --tp_size 32 --pp_size 1 --n_layers 32\
  --input_dir llama-31-8b.pt\
  --output_dir checkpoint/pretrained_weight\
  --convert_from_full_state --save_xser 1\
  --kv_size_multiplier 4 --qkv_linear 1 --config config.json

変換完了後、ウェイトをFSxボリュームにコピーし、Head Nodeに戻ります。

cp -r checkpoint/pretrained_weight /fsx/checkpoint/pretrained_weight
exit

学習スクリプトの編集

ダウンロードした学習スクリプトtp_zero1_llama3_8B_hf_pretrain.shに以下の修正を加えます。

  • データセットのパスを指定します。
DATA_PATH="/fsx/examples_datasets/wikicorpus_llama3_tokenized_8k"
  • 学習途中でウェイトを保存できるように、CHECKPOINT周りの設定を行います。
    (2ステップで学習を終わらせるため、CHECKPOINT_FREQは1に設定しました。)
CHECKPOINT_FREQ=1 # 10から1に変更
CHECKPOINT_DIR=/fsx/checkpoint
  • torchrunの箇所に、以下のオプションを付け加えます。
    --checkpoint_freq $CHECKPOINT_FREQ \
    --checkpoint_dir $CHECKPOINT_DIR \
    --pretrained_weight \

チェックポイントの保存先ディレクトリを作成します。

mkdir /fsx/checkpoint

また、今回は継続事前学習の完了を確認したいだけであるため、
2ステップで継続事前学習を終わらせるようにします。

if [ $NEURON_EXTRACT_GRAPHS_ONLY -gt 0 ]; then
    STEPS_THIS_RUN=2
    OUTPUT_LOG=log_compile-$NODE_ID.log
elif [ -v PERF_TEST ] && [ $PERF_TEST -gt 0 ]; then
    STEPS_THIS_RUN=100
    OUTPUT_LOG=log_exe-$NODE_ID.log
else
    STEPS_THIS_RUN=2 # ここの値を-1から2に変更する
    OUTPUT_LOG=log_exe-$NODE_ID.log
fi

学習ジョブの実行

モデルをコンパイルします。実行に10分程度かかります。

sbatch --exclusive --nodes 4\
  --cpus-per-task 128\
  --wrap="srun neuron_parallel_compile bash $(pwd)/tp_zero1_llama3_8B_hf_pretrain.sh"\
  --output=/fsx/slurm_out/slurm-%j.out

実行ステータスは、squeueコマンドを使って確認することができます。

squeue

学習ジョブを実行します。

sbatch --exclusive --nodes 4\
  --cpus-per-task 128\
  --wrap="srun bash $(pwd)/tp_zero1_llama3_8B_hf_pretrain.sh"\
  --output=/fsx/slurm_out/slurm-%j.out

スクリプトを実行し、ログを確認します。しばらくするとlossの値が出力され、学習が進んでいることが分かります。

$ cat /fsx/slurm_out/slurm-6.out | grep step_loss
LOG Mon Dec  9 05:56:18 2024 - (0, 1) step_loss : 2.1128 learning_rate : 1.50e-06 throughput : 4.62 seq/s 
LOG Mon Dec  9 05:59:01 2024 - (0, 2) step_loss : 2.1125 learning_rate : 3.00e-06 throughput : 5.33 seq/s 

継続事前学習済みモデルは、CHECKPOINT_DIRで指定したディレクトリに格納されています。今回は1ステップごとにチェックポイントを保存しているため、/fsx/checkpoint/step_2/model/が最終的な成果物となります。

$ ls /fsx/checkpoint/step_2/model/
dp_rank_00_tp_rank_00_pp_rank_00.pt
dp_rank_00_tp_rank_00_pp_rank_00.pt.info.pt
dp_rank_00_tp_rank_00_pp_rank_00.pt.tensors
...

動作確認を行うために、NxD形式で保存されている継続事前学習済みモデルをHuggingface形式に変換します。変換はWorker nodeで行うため、Worker NodeへのSSHログインが必要です。

ssh compute1-st-queue1-i1-1
source ~/aws_neuron_venv_pytorch/bin/activate
cd ~/examples/tp_zero1_llama_hf_pretrain

変換を実行後、exitコマンドでHead nodeに戻ります。

mkdir -p /fsx/checkpoint_hf/
python3 convert_checkpoints.py --model_style hf\
  --input_dir /fsx/checkpoint/step_2/model/\
  --output_dir /fsx/checkpoint_hf\
  --load_xser True --config config.json\
  --tp_size 32 --pp_size 1 --kv_size_multiplier 4\
  --qkv_linear True --convert_to_full_state

exit

Transformersで読み込めるように、ファイル名をpytorch_model.binにリネームします。

mv /fsx/checkpoint_hf/checkpoint.pt /fsx/checkpoint_hf/pytorch_model.bin

config.jsonなどの各種ファイルを全てpytorch_model.binと同階層にコピーします。

cd ~/examples/tp_zero1_llama_hf_pretrain
cp config.json /fsx/checkpoint_hf/
cp tokenizer.json /fsx/checkpoint_hf/
cp tokenizer_config.json /fsx/checkpoint_hf/
cp special_tokens_map.json /fsx/checkpoint_hf/

動作確認

Worker Node上で動作確認を行います。動作確認に必要なパッケージをインストールします。

pip config set global.extra-index-url https://pip.repos.neuron.amazonaws.com
pip install transformers==4.43.2 tokenizers==0.19.1 sentencepiece
pip install neuronx-cc==2.* torch-neuronx
pip install transformers-neuronx==0.12.313

プログラムを作成し、実行します。
Transformers-neuronxのサンプルノートブックを参考に作成しました。

tee ./llama31-8B-sampling.py > /dev/null << EOF
import time
import torch
from transformers import AutoTokenizer
from transformers_neuronx import LlamaForSampling
from transformers import LlamaForCausalLM, LlamaTokenizer, PreTrainedTokenizerFast
from transformers_neuronx import LlamaForSampling, NeuronConfig, GQA, QuantizationConfig
from transformers_neuronx.config import GenerationConfig 

# Set this to the Hugging Face model ID
model_id = "/fsx/checkpoint_hf/"

neuron_config = NeuronConfig(
                    on_device_embedding=False,
                    attention_layout='BSH',
                    fuse_qkv=True,
                    group_query_attention=GQA.REPLICATED_HEADS,
                    quant=QuantizationConfig(quant_dtype='s8', dequant_dtype='bf16'),
              )

# load meta-llama/Llama-3.1-8B to the NeuronCores with 32-way tensor parallelism and run compilation
neuron_model = LlamaForSampling.from_pretrained(model_id, neuron_config=neuron_config, batch_size=1, tp_degree=32, amp='bf16', n_positions=8192)
neuron_model.to_neuron()

# construct a tokenizer and encode prompt text
tokenizer = AutoTokenizer.from_pretrained(model_id)
prompt = "Hello, I'm a language model and I like to"
input_ids = tokenizer.encode(prompt, return_tensors="pt")

# run inference with top-k sampling
with torch.inference_mode():
    start = time.time()
    generated_sequences = neuron_model.sample(input_ids, sequence_length=8196, top_k=50)
    elapsed = time.time() - start

generated_sequences = [tokenizer.decode(seq) for seq in generated_sequences]
print(f'generated sequences {generated_sequences} in {elapsed} seconds')

EOF

最後に、pythonプログラムを実行します。モデルのロードにかなり時間がかかりますが、推論結果が出力されていればOKです。
今回は2ステップのみの実行ということで、継続事前学習の効果は体感できませんが、Trainiumで継続事前学習を一通り実行できることが確認できたと思います。

python llama31-8B-sampling.py
# 出力
... (ログ出力)
generated sequences ['<|begin_of_text|>Hello, I\'m a language model and I like to discuss about any topic. For that reason, I\'ve prepared a set of questions to help you have an interesting and educational conversation. Do you mind talking with me?\n
...(中略)
and friends in the office.<|end_of_text|>'] in 19.089792013168335 seconds

躓きポイント

tp_zero1_llama3_8B_hf_pretrain.shのソースコードでLlama3.1は動作するのか?

動作します。Llama3.1用のスクリプトもGitHubには存在しますが(tp_zero1_llama31_8B_hf_pretrain.sh)、
内部的には tp_zero1_llama3_8B_hf_pretrain.shを実行しているだけなので問題ありません。

後片付け

Clusterの削除

pclusterコマンドのdelete-clusterを用いてClusterを削除します。

pcluster delete-cluster --cluster-name cluster-test --region us-west-2

VPCの削除

マネジメントコンソールから「VPC」を検索し、作成したVPCを選択し、「VPCの削除」をクリックします。VPC内のリソースが削除されていない場合は削除し、VPCを削除します。

おわりに

  • 今までGPUで学習するのが当たり前だと思っていたため、Trainiumを利用した分散学習でも無事に実行できる点で新鮮さがありました。
  • 学習自体のコストパフォーマンスは優れているかもしれませんが、学習に至るまでのソースコード修正に時間がかかる感覚がありました。
    • NxD周りは開発中のものが多く、情報が少ないため学習を動かすまでの時間がかかってしまいました。
    • ある程度サービスが整ってくると使いやすくなるのではと期待しています。
  • trn2インスタンスが発表されたため、より高速化されることを期待しています!

謝辞

本記事は、AWS Generative AI Innovation Center の支援を受けて執筆いたしました。ご協力いただいたAWSの皆様、ありがとうございました。

参考文献

12
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
12
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?