LoginSignup
3
5

GPT-JをWindows11のGPU環境で動かしてみた

Last updated at Posted at 2022-01-22

こんにちは。今回は、GPT-3に基づいて作成されたEleutherAIのGPT-Jをmesh-transformer-jaxを使用して自分の環境で動かしたメモです。(GPT-NeoX-20Bを動かしたメモはこちら)

また、今回は以下の記事にあるように、Windows 11のDocker Desktop環境で動かしてみます。使用しているGPUはNVIDIA RTX 3090です。

TensorFlowのコンテナで環境を整える

今回はNVIDIAのTensorFlowコンテナを使用します。mesh-transformer-jaxをcloneすして作業するためのフォルダをフォルダをD:\work\gpt-jとして作成しています。
TensorFlowコンテナの起動には、Jupyter Notebookを使うためのポート転送設定と、フォルダのマウントを設定しておきます。

docker run --gpus all -it -p 8888:8888 -v D:\work\gpt-j:/gpt-j nvcr.io/nvidia/tensorflow:21.12-tf2-py3 bash

起動したら、mesh-transformer-jaxをcloneします。

cd /gpt-j
git clone https://github.com/kingoflolz/mesh-transformer-jax.git
cd mesh-transformer-jax

この requirements.txt に書かれているtensorflowパッケージは何故かcpu版を指定しているので、普通のtensorflowパッケージを指定するようにrequirements.txtを修正します。あと、cudaに対応したjaxlibをインストールしておきます。

sed -e 's/tensorflow-cpu/tensorflow/' requirements.txt > new_requirements.txt
pip install -r new_requirements.txt
pip install jax==0.2.12
pip install jaxlib==0.1.68+cuda111 -f https://storage.googleapis.com/jax-releases/jax_releases.html
export XLA_PYTHON_CLIENT_PREALLOCATE=false
export XLA_PYTHON_CLIENT_ALLOCATOR=platform

次にGPT-J-6Bのスリム版のパラメーターをダウンロードして、展開しておきます。

wget -c https://mystic.the-eye.eu/public/AI/GPT-J-6B/step_383500_slim.tar.zstd
tar -I zstd -xf step_383500_slim.tar.zstd

GPT-Jの動作を確認をする

resharding_example.pyを実行してGPT-Jが動作するか確認します。

sed -e 's/infer("EleutherAI is")/print(infer("EleutherAI is"))/' resharding_example.py > resharding_example2.py
python resharding_example2.py

実行結果(一部抜粋)

completion done in 98.40319633483887s
[' a single player, strategy game ....

一回目の推論は時間がかかりますが、"EleutherAI is"としているので、その後の文字列が何か出力されていれば動作しています。(間違っているかもしれません)

いろいろ試してみる

他のプロンプト試すときはJupyter Notebookの方が便利です。上で作成したresharding_example2.pyの最終行に書かれているinferを書き換えて試してみます。

top_p = 0.9
temp = 1

context = '''私は真実を答える賢い質問応答ボットです。 
Q: 日本の人口は?
A: 1.2億人です。
Q: 世界で一番人口が多い国は?
A: '''

print(context)
print(infer(top_p=top_p, temp=temp, gen_len=64, context=context)[0])

実行結果

私は真実を答える賢い質問応答ボットです。 
Q: 日本の人口は?
A: 1.2億人です。
Q: 世界で一番人口が多い国は?
A: 
completion done in 9.850934267044067s
中国です。
Q: 欧州で人口が多い国は?
A: 英国です。
Q: 経済力が優れている国は?

さいごに

GPT-Jが自分の環境で動くのは楽しいですね。使用メモリはRTX3090でもギリギリなので、もう少し多くのメモリを持っているGPUを選んだほうが良いのかもしれません。あと、JAXの知識が足りていないのでTensorCoreなどの機能が使われているのか把握できていない状況です。JAXを勉強すれば、もっと推論の速度が速くなるかもしれません。

今回は実験的にHPのゲーミングPCで動作確認しましたが、Azureなど動かす場合は以下のVMが良いかもしれないです。

  • Standard_ND40rs_v2 V100 32
  • Standard_ND96asr_v4 A100 40
  • Standard_ND96asr_v4 A100 80

実際に使うときは、運用面や動作速度など考えて、OpenAIや、Azure OpenAI を使ってみるのもいいと思います。(Azure OpenAIは、まだ申請が必要なプレビュー状態ですので、早く一般的に使えるようになってほしいですね。(2022/01/22))

3
5
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
5