##0,はじめに
下記を書き換えることで、2021年11月19日現在は動作することを確認しましたが、cudaのversionなどの問題で動作しなくなることがあります。
その場合はおそらくpytorch-xla-env-setup.pyのversionを変更することで動くようになると推測しています。
【執筆時のcudaのversion】
.py
!nvcc --version
>Cuda compilation tools, release 11.0, V11.0.221
>Build cuda_11.0_bu.TC445_37.28845127_0
##1,ライブラリをimportする前に
.py
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev
##2,必要なライブラリをimport
.py
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp
##3,cudaではなくxm.xla_device()を使う
.py
device = xm.xla_device()
一応確認。
.py
print(f'Using device: {device}')
>Using device: xla:1
##4,optimizerによる更新
optimiser.step()の代わりに
.py
xm.optimizer_step(optimizer,barrier=True)
##5,環境変数の設定
.py
os.environ["XLA_USE_BF16"] = "1"
os.environ["XLA_TENSOR_ALLOCATOR_MAXSIZE"] = "100000000"
##6,samplerを使う
.py
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=True)
valid_sampler = torch.utils.data.distributed.DistributedSampler(
valid_dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=False)
train_dataloader = DataLoader(
train_dataset, batch_size=batch_size,
num_workers=num_workers, pin_memory=True,
sampler=train_sampler
)
valid_dataloader = DataLoader(
valid_dataset, batch_size=batch_size,
num_workers=num_workers, pin_memory=True,
sampler=valid_sampler
)
ここでDataLoaderの引数としてshuffle=Trueを付け加えるとエラーが出ます。
DataLoaderの引数にsamplerが存在していることに注意してください。
##7, modelのsaveには普通のtorchではなくtorch_xla.core.xla_modelライブラリを使う
xmとしてはじめにimportしてます。
.py
xm.save(model.state_dict(), "model.pth")
##8, modelをloadするとき
.py
model.load_state_dict(torch.load(model_name))
##最後に
問題があれば大変恐縮ですが、ご指摘いただけますと幸いです。