2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

【Pytorch】kaggle notebookでPytorchのGPUのコードをTPUのコードに書き換える【GPU→TPU】

Last updated at Posted at 2021-11-19

##0,はじめに
下記を書き換えることで、2021年11月19日現在は動作することを確認しましたが、cudaのversionなどの問題で動作しなくなることがあります。
その場合はおそらくpytorch-xla-env-setup.pyのversionを変更することで動くようになると推測しています。

【執筆時のcudaのversion】

.py
!nvcc --version
>Cuda compilation tools, release 11.0, V11.0.221
>Build cuda_11.0_bu.TC445_37.28845127_0

##1,ライブラリをimportする前に

.py
!curl https://raw.githubusercontent.com/pytorch/xla/master/contrib/scripts/env-setup.py -o pytorch-xla-env-setup.py
!python pytorch-xla-env-setup.py --version 1.7 --apt-packages libomp5 libopenblas-dev

##2,必要なライブラリをimport

.py
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

##3,cudaではなくxm.xla_device()を使う

.py
device = xm.xla_device()

一応確認。

.py
print(f'Using device: {device}')
>Using device: xla:1

##4,optimizerによる更新
optimiser.step()の代わりに

.py
xm.optimizer_step(optimizer,barrier=True)

##5,環境変数の設定

.py
os.environ["XLA_USE_BF16"] = "1"
os.environ["XLA_TENSOR_ALLOCATOR_MAXSIZE"] = "100000000"

##6,samplerを使う

.py
train_sampler = torch.utils.data.distributed.DistributedSampler(
train_dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=True)

valid_sampler = torch.utils.data.distributed.DistributedSampler(
valid_dataset,
num_replicas=xm.xrt_world_size(),
rank=xm.get_ordinal(),
shuffle=False)

train_dataloader = DataLoader(
train_dataset, batch_size=batch_size,
num_workers=num_workers, pin_memory=True,
sampler=train_sampler
)
valid_dataloader = DataLoader(
valid_dataset, batch_size=batch_size,
num_workers=num_workers, pin_memory=True,
sampler=valid_sampler
)

ここでDataLoaderの引数としてshuffle=Trueを付け加えるとエラーが出ます。
DataLoaderの引数にsamplerが存在していることに注意してください。

##7, modelのsaveには普通のtorchではなくtorch_xla.core.xla_modelライブラリを使う

xmとしてはじめにimportしてます。

.py
xm.save(model.state_dict(), "model.pth")

##8, modelをloadするとき

.py
model.load_state_dict(torch.load(model_name))

##最後に
問題があれば大変恐縮ですが、ご指摘いただけますと幸いです。

2
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
4

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?