こんにちは。今回は、GPT-3に基づいて作成されたEleutherAIのGPT-Jをmesh-transformer-jaxを使用して自分の環境で動かしたメモです。(GPT-NeoX-20Bを動かしたメモはこちら)
また、今回は以下の記事にあるように、Windows 11のDocker Desktop環境で動かしてみます。使用しているGPUはNVIDIA RTX 3090です。
TensorFlowのコンテナで環境を整える
今回はNVIDIAのTensorFlowコンテナを使用します。mesh-transformer-jaxをcloneすして作業するためのフォルダをフォルダをD:\work\gpt-jとして作成しています。
TensorFlowコンテナの起動には、Jupyter Notebookを使うためのポート転送設定と、フォルダのマウントを設定しておきます。
docker run --gpus all -it -p 8888:8888 -v D:\work\gpt-j:/gpt-j nvcr.io/nvidia/tensorflow:21.12-tf2-py3 bash
起動したら、mesh-transformer-jaxをcloneします。
cd /gpt-j
git clone https://github.com/kingoflolz/mesh-transformer-jax.git
cd mesh-transformer-jax
この requirements.txt に書かれているtensorflowパッケージは何故かcpu版を指定しているので、普通のtensorflowパッケージを指定するようにrequirements.txtを修正します。あと、cudaに対応したjaxlibをインストールしておきます。
sed -e 's/tensorflow-cpu/tensorflow/' requirements.txt > new_requirements.txt
pip install -r new_requirements.txt
pip install jax==0.2.12
pip install jaxlib==0.1.68+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform
次にGPT-J-6Bのスリム版のパラメーターをダウンロードして、展開しておきます。
wget -c https://mystic.the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
tar -I zstd -xf step_383500_slim.tar.zstd
GPT-Jの動作を確認をする
resharding_example.pyを実行してGPT-Jが動作するか確認します。
sed -e 's/infer("EleutherAI is")/print(infer("EleutherAI is"))/' resharding_example.py > resharding_example2.py
python resharding_example2.py
実行結果(一部抜粋)
completion done in 98.40319633483887s
[' a single player, strategy game ....
一回目の推論は時間がかかりますが、"EleutherAI is"としているので、その後の文字列が何か出力されていれば動作しています。(間違っているかもしれません)
いろいろ試してみる
他のプロンプト試すときはJupyter Notebookの方が便利です。上で作成したresharding_example2.pyの最終行に書かれているinferを書き換えて試してみます。
top_p = 0.9
temp = 1
context = '''私は真実を答える賢い質問応答ボットです。
Q: 日本の人口は?
A: 1.2億人です。
Q: 世界で一番人口が多い国は?
A: '''
print(context)
print(infer(top_p=top_p, temp=temp, gen_len=64, context=context)[0])
実行結果
私は真実を答える賢い質問応答ボットです。
Q: 日本の人口は?
A: 1.2億人です。
Q: 世界で一番人口が多い国は?
A:
completion done in 9.850934267044067s
中国です。
Q: 欧州で人口が多い国は?
A: 英国です。
Q: 経済力が優れている国は?
さいごに
GPT-Jが自分の環境で動くのは楽しいですね。使用メモリはRTX3090でもギリギリなので、もう少し多くのメモリを持っているGPUを選んだほうが良いのかもしれません。あと、JAXの知識が足りていないのでTensorCoreなどの機能が使われているのか把握できていない状況です。JAXを勉強すれば、もっと推論の速度が速くなるかもしれません。
今回は実験的にHPのゲーミングPCで動作確認しましたが、Azureなど動かす場合は以下のVMが良いかもしれないです。
- Standard_ND40rs_v2 V100 32
- Standard_ND96asr_v4 A100 40
- Standard_ND96asr_v4 A100 80
実際に使うときは、運用面や動作速度など考えて、OpenAIや、Azure OpenAI を使ってみるのもいいと思います。(Azure OpenAIは、まだ申請が必要なプレビュー状態ですので、早く一般的に使えるようになってほしいですね。(2022/01/22))