一昨昨日の記事の続きだが、スーパーコンピュータ「富岳」のHorovod with PyTorch-1.7.0上で、JGLUEを複数ノードで動かしてみることにした。あれこれ調べてみたところ、富岳は12ノードが基本単位で、12ノードを超える通信は遅くなってしまう。通信が遅くなると、Horovodのbroadcastとかが間に合わない。何度か失敗を繰り返した後、とりあえず5×12=60ノードに落ち着いた。
#! /bin/bash
#PJM -L rscgrp=small
#PJM -L elapse=1:00:00
#PJM -L node=5x12:torus
#PJM -j
#PJM -S
G=`id | sed 's/^.*gid=[0-9]*(\([^)]*\)).*$/\1/'`
set `ls -d /vol*/$G /vol*/data/$G` $HOME
export PYTHONUSERBASE=$1/jglue
export PATH=/home/apps/oss/PyTorch-1.7.0/bin:$PYTHONUSERBASE/bin:$PATH
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/home/apps/oss/PyTorch-1.7.0/lib64
export HF_HOME=$PYTHONUSERBASE
export TMPDIR=$PYTHONUSERBASE/tmp
mkdir -p $TMPDIR
T=$TMPDIR/transformers-4.28.1
if [ ! -d $T ]
then git clone -b v4.28.1 --depth=1 https://github.com/huggingface/transformers $T
fi
J=$TMPDIR/JGLUE
if [ ! -d $J ]
then git clone --depth=1 https://github.com/yahoojapan/JGLUE $J
cat $J/fine-tuning/patch/transformers-4.9.2_jglue-1.1.0.patch | ( cd $T && patch -p1 )
pip3.8 install tokenizers==0.13.3 protobuf==3.20.3 accelerate==0.20.3 --user
( cd $T && pip3.8 install . --user )
pip3.8 install -r $T/examples/pytorch/text-classification/requirements.txt --user
pip3.8 install -U tqdm packaging typing_extensions --user
fi
P=$TMPDIR/run_swag.$PJM_JOBID.$$.py
cp $T/examples/pytorch/multiple-choice/run_swag.py $P
if [ ${PJM_NODE-1} -lt 2 ]
then M=python3.8
else M="env LD_PRELOAD=libtcmalloc.so mpirun -np $PJM_NODE python3 -u"
ex -s $P << 'EOF'
/^check_min_version/a
import horovod.torch as hvd
hvd.init()
.
/# Training/i
trainer.create_optimizer()
trainer.optimizer = hvd.DistributedOptimizer(trainer.optimizer, model.named_parameters())
trainer._get_train_sampler = lambda: torch.utils.data.distributed.DistributedSampler(trainer.train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
trainer._get_eval_sampler = lambda x: torch.utils.data.distributed.DistributedSampler(x, num_replicas=hvd.size(), rank=hvd.rank())
hvd.broadcast_parameters(model.state_dict(), root_rank=0)
.
wq
EOF
fi
$M $P --model_name_or_path KoichiYasuoka/deberta-base-japanese-wikipedia --do_train --do_eval --max_seq_length 64 --per_device_train_batch_size 16 --learning_rate 5e-05 --num_train_epochs 4 --output_dir $TMPDIR/output.${PJM_JOBID-$$} --overwrite_output_dir --train_file $J/datasets/jcommonsenseqa-v1.1/train-v1.1.json --validation_file $J/datasets/jcommonsenseqa-v1.1/valid-v1.1.json --use_fast_tokenizer True --evaluation_strategy epoch --warmup_ratio 0.1
富岳のログインノードからpjsub
して、deberta-base-japanese-wikipediaのJCommonSenseQAを測ってみたところ、ファインチューニングが7分で終わり、output_dir
のeval_results.json
に以下の結果が出力された。
{
"epoch": 4.0,
"eval_accuracy": 0.7894737124443054,
"eval_loss": 0.9615128636360168,
"eval_runtime": 1.8881,
"eval_samples": 1119,
"eval_samples_per_second": 592.65,
"eval_steps_per_second": 74.147
}
eval_accuracy
が0.789もあって、私(安岡孝一)の2022年6月25日の日記よりも、そこそこ高い。まあ、7分×60ノード=7時間のファインチューニング、と考えれば、それなりの結果なのかしら。