LoginSignup
0
1

富岳のHorovod with PyTorch-1.7.0でJGLUEのJCommonSenseQAを動かすには

Posted at

一昨昨日の記事の続きだが、スーパーコンピュータ「富岳」のHorovod with PyTorch-1.7.0上で、JGLUEを複数ノードで動かしてみることにした。あれこれ調べてみたところ、富岳は12ノードが基本単位で、12ノードを超える通信は遅くなってしまう。通信が遅くなると、Horovodのbroadcastとかが間に合わない。何度か失敗を繰り返した後、とりあえず5×12=60ノードに落ち着いた。

#! /bin/bash
#PJM -L rscgrp=small
#PJM -L elapse=1:00:00
#PJM -L node=5x12:torus
#PJM -j
#PJM -S

G=`id | sed 's/^.*gid=[0-9]*(\([^)]*\)).*$/\1/'`
set `ls -d /vol*/$G /vol*/data/$G` $HOME
export PYTHONUSERBASE=$1/jglue
export PATH=/home/apps/oss/PyTorch-1.7.0/bin:$PYTHONUSERBASE/bin:$PATH
export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/home/apps/oss/PyTorch-1.7.0/lib64
export HF_HOME=$PYTHONUSERBASE
export TMPDIR=$PYTHONUSERBASE/tmp
mkdir -p $TMPDIR
T=$TMPDIR/transformers-4.28.1
if [ ! -d $T ]
then git clone -b v4.28.1 --depth=1 https://github.com/huggingface/transformers $T
fi
J=$TMPDIR/JGLUE
if [ ! -d $J ]
then git clone --depth=1 https://github.com/yahoojapan/JGLUE $J
     cat $J/fine-tuning/patch/transformers-4.9.2_jglue-1.1.0.patch | ( cd $T && patch -p1 )
     pip3.8 install tokenizers==0.13.3 protobuf==3.20.3 accelerate==0.20.3 --user
     ( cd $T && pip3.8 install . --user )
     pip3.8 install -r $T/examples/pytorch/text-classification/requirements.txt --user
     pip3.8 install -U tqdm packaging typing_extensions --user
fi

P=$TMPDIR/run_swag.$PJM_JOBID.$$.py
cp $T/examples/pytorch/multiple-choice/run_swag.py $P
if [ ${PJM_NODE-1} -lt 2 ]
then M=python3.8
else M="env LD_PRELOAD=libtcmalloc.so mpirun -np $PJM_NODE python3 -u"
     ex -s $P << 'EOF'
/^check_min_version/a
import horovod.torch as hvd
hvd.init()
.
/# Training/i
    trainer.create_optimizer()
    trainer.optimizer = hvd.DistributedOptimizer(trainer.optimizer, model.named_parameters())
    trainer._get_train_sampler = lambda: torch.utils.data.distributed.DistributedSampler(trainer.train_dataset, num_replicas=hvd.size(), rank=hvd.rank())
    trainer._get_eval_sampler = lambda x: torch.utils.data.distributed.DistributedSampler(x, num_replicas=hvd.size(), rank=hvd.rank())
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
.
wq
EOF
fi
$M $P --model_name_or_path KoichiYasuoka/deberta-base-japanese-wikipedia --do_train --do_eval --max_seq_length 64 --per_device_train_batch_size 16 --learning_rate 5e-05 --num_train_epochs 4 --output_dir $TMPDIR/output.${PJM_JOBID-$$} --overwrite_output_dir --train_file $J/datasets/jcommonsenseqa-v1.1/train-v1.1.json --validation_file $J/datasets/jcommonsenseqa-v1.1/valid-v1.1.json --use_fast_tokenizer True --evaluation_strategy epoch --warmup_ratio 0.1

富岳のログインノードからpjsubして、deberta-base-japanese-wikipediaのJCommonSenseQAを測ってみたところ、ファインチューニングが7分で終わり、output_direval_results.jsonに以下の結果が出力された。

{
    "epoch": 4.0,
    "eval_accuracy": 0.7894737124443054,
    "eval_loss": 0.9615128636360168,
    "eval_runtime": 1.8881,
    "eval_samples": 1119,
    "eval_samples_per_second": 592.65,
    "eval_steps_per_second": 74.147
}

eval_accuracyが0.789もあって、私(安岡孝一)の2022年6月25日の日記よりも、そこそこ高い。まあ、7分×60ノード=7時間のファインチューニング、と考えれば、それなりの結果なのかしら。

0
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
1