It has been announced that the next version of PyTorch (v1.12) will be able to train on the GPU of Apple Silicon Macs. A preview build is already available.
- Introducing Accelerated PyTorch Training on Mac | PyTorch
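- The GitHub issue was "GPU acceleration for Apple's M1 chip? · Issue #47702 · pytorch/pytorch", but it was locked a few months ago, apparently because people from the S4TF (Swift for TensorFlow) community kept posting off-topic comments.
So I tried it out right away.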
Code
The dataset and model are the FashionMNIST ones from the PyTorch tutorial.
You would normally use cuda as the device, but for the Mac GPU the device is mps (Metal Performance Shaders). See the PyTorch documentation for details.
Code:
import argparse

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor


class MyNetwork(nn.Module):
    def __init__(self):
        super(MyNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


def get_default_device() -> str:
    # Prefer cuda, then mps, then cpu.
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps does not exist in older releases, hence the getattr.
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"


def train(dataloader: DataLoader, model: MyNetwork, loss_fn: nn.CrossEntropyLoss, optimizer: torch.optim.SGD, device: str) -> None:
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Forward pass and loss
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Report progress every 100 batches
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")


def test(dataloader: DataLoader, model: MyNetwork, loss_fn: nn.CrossEntropyLoss, device: str) -> None:
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # Count predictions whose argmax matches the label
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error:\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")


def main() -> None:
    parser = argparse.ArgumentParser(description="FashionMNIST")
    parser.add_argument("--device", default=get_default_device(), help="Device to use")
    args = parser.parse_args()
    device = args.device

    training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
    test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())

    batch_size = 64
    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    # Print the shape of one batch, then stop iterating
    for X, y in test_dataloader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}")
        break

    print(f"Using {device} device")
    model = MyNetwork().to(device)
    print(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    epochs = 10
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------")
        train(train_dataloader, model, loss_fn, optimizer, device=device)
        test(test_dataloader, model, loss_fn, device=device)
    print("Done!")


if __name__ == "__main__":
    main()
Results
I ran this on a Mac mini (M1, 2020, 8-core GPU). To install PyTorch, run something like:
pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
The URL contains cpu, but that is nothing to worry about. If torch.backends.mps.is_available() returns True, the installation worked.
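If the check does not return True, it can help to look at a few more facts at once. The following is only a small sketch: the function name mps_status and the keys of the dictionary are made up for illustration, and it assumes that a native Apple Silicon Python reports arm64 from platform.machine().

```python
import importlib.util
import platform


def mps_status() -> dict:
    """Collect a few facts that help explain why mps is or is not usable.

    Hypothetical helper for illustration; not part of PyTorch.
    """
    status = {
        # Assumption: "arm64" for a native Apple Silicon Python.
        "machine": platform.machine(),
        "torch_installed": importlib.util.find_spec("torch") is not None,
        "torch_version": None,
        "mps_available": False,
    }
    if status["torch_installed"]:
        import torch

        status["torch_version"] = torch.__version__
        # Older releases have no torch.backends.mps at all, hence the getattr.
        mps = getattr(torch.backends, "mps", None)
        status["mps_available"] = bool(mps is not None and mps.is_available())
    return status


if __name__ == "__main__":
    for key, value in mps_status().items():
        print(f"{key}: {value}")
```

The getattr guard is the same one used in get_default_device() in the script, so the check also runs on a PyTorch build that predates the mps backend.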
First, the CPU.
% time python main.py --device cpu
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
Using cpu device
MyNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
)
Epoch 1
-------------
loss: 2.293083 [ 0/60000]
loss: 2.283966 [ 6400/60000]
loss: 2.265248 [12800/60000]
loss: 2.260847 [19200/60000]
loss: 2.241223 [25600/60000]
loss: 2.214097 [32000/60000]
loss: 2.221246 [38400/60000]
loss: 2.191444 [44800/60000]
loss: 2.184505 [51200/60000]
loss: 2.152950 [57600/60000]
Test Error:
Accuracy: 52.5%, Avg loss: 2.147460
Epoch 2
-------------
loss: 2.154118 [ 0/60000]
loss: 2.147166 [ 6400/60000]
loss: 2.085644 [12800/60000]
loss: 2.106354 [19200/60000]
loss: 2.053174 [25600/60000]
loss: 1.997061 [32000/60000]
loss: 2.023799 [38400/60000]
loss: 1.948802 [44800/60000]
loss: 1.954430 [51200/60000]
loss: 1.882008 [57600/60000]
Test Error:
Accuracy: 58.4%, Avg loss: 1.874473
Epoch 3
-------------
loss: 1.906654 [ 0/60000]
loss: 1.880482 [ 6400/60000]
loss: 1.755886 [12800/60000]
loss: 1.799277 [19200/60000]
loss: 1.690940 [25600/60000]
loss: 1.644767 [32000/60000]
loss: 1.664432 [38400/60000]
loss: 1.566495 [44800/60000]
loss: 1.593532 [51200/60000]
loss: 1.487655 [57600/60000]
Test Error:
Accuracy: 60.2%, Avg loss: 1.500291
Epoch 4
-------------
loss: 1.567383 [ 0/60000]
loss: 1.536163 [ 6400/60000]
loss: 1.379712 [12800/60000]
loss: 1.453804 [19200/60000]
loss: 1.334036 [25600/60000]
loss: 1.333584 [32000/60000]
loss: 1.346803 [38400/60000]
loss: 1.272417 [44800/60000]
loss: 1.310961 [51200/60000]
loss: 1.213349 [57600/60000]
Test Error:
Accuracy: 62.9%, Avg loss: 1.235704
Epoch 5
-------------
loss: 1.310030 [ 0/60000]
loss: 1.295836 [ 6400/60000]
loss: 1.127985 [12800/60000]
loss: 1.235680 [19200/60000]
loss: 1.106518 [25600/60000]
loss: 1.139977 [32000/60000]
loss: 1.159387 [38400/60000]
loss: 1.099196 [44800/60000]
loss: 1.141269 [51200/60000]
loss: 1.060075 [57600/60000]
Test Error:
Accuracy: 64.4%, Avg loss: 1.076839
Epoch 6
-------------
loss: 1.143180 [ 0/60000]
loss: 1.148399 [ 6400/60000]
loss: 0.967611 [12800/60000]
loss: 1.103710 [19200/60000]
loss: 0.970504 [25600/60000]
loss: 1.013126 [32000/60000]
loss: 1.046940 [38400/60000]
loss: 0.993126 [44800/60000]
loss: 1.033323 [51200/60000]
loss: 0.965677 [57600/60000]
Test Error:
Accuracy: 65.8%, Avg loss: 0.975701
Epoch 7
-------------
loss: 1.029155 [ 0/60000]
loss: 1.052933 [ 6400/60000]
loss: 0.859105 [12800/60000]
loss: 1.016738 [19200/60000]
loss: 0.885197 [25600/60000]
loss: 0.923422 [32000/60000]
loss: 0.973342 [38400/60000]
loss: 0.925087 [44800/60000]
loss: 0.958950 [51200/60000]
loss: 0.901543 [57600/60000]
Test Error:
Accuracy: 67.2%, Avg loss: 0.906411
Epoch 8
-------------
loss: 0.945191 [ 0/60000]
loss: 0.985683 [ 6400/60000]
loss: 0.780831 [12800/60000]
loss: 0.955175 [19200/60000]
loss: 0.827421 [25600/60000]
loss: 0.857079 [32000/60000]
loss: 0.921167 [38400/60000]
loss: 0.879375 [44800/60000]
loss: 0.904946 [51200/60000]
loss: 0.854444 [57600/60000]
Test Error:
Accuracy: 68.1%, Avg loss: 0.855827
Epoch 9
-------------
loss: 0.879951 [ 0/60000]
loss: 0.934363 [ 6400/60000]
loss: 0.721371 [12800/60000]
loss: 0.909010 [19200/60000]
loss: 0.785734 [25600/60000]
loss: 0.806195 [32000/60000]
loss: 0.881165 [38400/60000]
loss: 0.847154 [44800/60000]
loss: 0.864123 [51200/60000]
loss: 0.817849 [57600/60000]
Test Error:
Accuracy: 69.4%, Avg loss: 0.817021
Epoch 10
-------------
loss: 0.827321 [ 0/60000]
loss: 0.892991 [ 6400/60000]
loss: 0.674449 [12800/60000]
loss: 0.873180 [19200/60000]
loss: 0.754036 [25600/60000]
loss: 0.766333 [32000/60000]
loss: 0.848583 [38400/60000]
loss: 0.823278 [44800/60000]
loss: 0.831912 [51200/60000]
loss: 0.787997 [57600/60000]
Test Error:
Accuracy: 70.8%, Avg loss: 0.785940
Done!
python main.py --device cpu 28.78s user 9.98s system 126% cpu 30.635 total
Nothing unusual here.
Next, the GPU.
% time python main.py --device mps
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
Using mps device
MyNetwork(
(flatten): Flatten(start_dim=1, end_dim=-1)
(linear_relu_stack): Sequential(
(0): Linear(in_features=784, out_features=512, bias=True)
(1): ReLU()
(2): Linear(in_features=512, out_features=512, bias=True)
(3): ReLU()
(4): Linear(in_features=512, out_features=10, bias=True)
)
)
Epoch 1
-------------
loss: 2.308286 [ 0/60000]
loss: 2.300420 [ 6400/60000]
loss: 2.284562 [12800/60000]
loss: 2.275385 [19200/60000]
loss: 2.261804 [25600/60000]
loss: 2.234581 [32000/60000]
loss: 2.240901 [38400/60000]
loss: 2.208639 [44800/60000]
loss: 2.214520 [51200/60000]
loss: 2.177371 [57600/60000]
Test Error:
Accuracy: 0.0%, Avg loss: 2.178863
Epoch 2
-------------
loss: 2.192659 [ 0/60000]
loss: 2.186912 [ 6400/60000]
loss: 2.135712 [12800/60000]
loss: 2.147661 [19200/60000]
loss: 2.102172 [25600/60000]
loss: 2.052881 [32000/60000]
loss: 2.078633 [38400/60000]
loss: 2.008227 [44800/60000]
loss: 2.017821 [51200/60000]
loss: 1.943297 [57600/60000]
Test Error:
Accuracy: 0.0%, Avg loss: 1.944851
Epoch 3
-------------
loss: 1.981287 [ 0/60000]
loss: 1.952881 [ 6400/60000]
loss: 1.845932 [12800/60000]
loss: 1.874648 [19200/60000]
loss: 1.764092 [25600/60000]
loss: 1.722995 [32000/60000]
loss: 1.738531 [38400/60000]
loss: 1.639645 [44800/60000]
loss: 1.657003 [51200/60000]
loss: 1.547692 [57600/60000]
Test Error:
Accuracy: 0.0%, Avg loss: 1.566313
Epoch 4
-------------
loss: 1.635859 [ 0/60000]
loss: 1.593795 [ 6400/60000]
loss: 1.450184 [12800/60000]
loss: 1.509478 [19200/60000]
loss: 1.384481 [25600/60000]
loss: 1.386362 [32000/60000]
loss: 1.399653 [38400/60000]
loss: 1.316242 [44800/60000]
loss: 1.344521 [51200/60000]
loss: 1.250790 [57600/60000]
Test Error:
Accuracy: 0.0%, Avg loss: 1.273837
Epoch 5
-------------
loss: 1.351406 [ 0/60000]
loss: 1.328287 [ 6400/60000]
loss: 1.167860 [12800/60000]
loss: 1.267555 [19200/60000]
loss: 1.138780 [25600/60000]
loss: 1.168856 [32000/60000]
loss: 1.193459 [38400/60000]
loss: 1.118520 [44800/60000]
loss: 1.153710 [51200/60000]
loss: 1.081497 [57600/60000]
Test Error:
Accuracy: 0.0%, Avg loss: 1.097177
Epoch 6
-------------
loss: 1.167233 [ 0/60000]
loss: 1.165472 [ 6400/60000]
loss: 0.988221 [12800/60000]
loss: 1.121341 [19200/60000]
loss: 0.993114 [25600/60000]
loss: 1.028384 [32000/60000]
loss: 1.068532 [38400/60000]
loss: 0.995041 [44800/60000]
loss: 1.033682 [51200/60000]
loss: 0.977396 [57600/60000]
Test Error:
Accuracy: 0.0%, Avg loss: 0.986301
Epoch 7
-------------
loss: 1.043652 [ 0/60000]
loss: 1.062958 [ 6400/60000]
loss: 0.868568 [12800/60000]
loss: 1.026042 [19200/60000]
loss: 0.904372 [25600/60000]
loss: 0.932638 [32000/60000]
loss: 0.988347 [38400/60000]
loss: 0.915102 [44800/60000]
loss: 0.952812 [51200/60000]
loss: 0.908595 [57600/60000]
Test Error:
Accuracy: 0.0%, Avg loss: 0.912048
Epoch 8
-------------
loss: 0.954836 [ 0/60000]
loss: 0.992852 [ 6400/60000]
loss: 0.783999 [12800/60000]
loss: 0.959045 [19200/60000]
loss: 0.846602 [25600/60000]
loss: 0.863988 [32000/60000]
loss: 0.932708 [38400/60000]
loss: 0.861322 [44800/60000]
loss: 0.895089 [51200/60000]
loss: 0.859447 [57600/60000]
Test Error:
Accuracy: 0.0%, Avg loss: 0.859100
Epoch 9
-------------
loss: 0.887372 [ 0/60000]
loss: 0.940760 [ 6400/60000]
loss: 0.721235 [12800/60000]
loss: 0.909224 [19200/60000]
loss: 0.805790 [25600/60000]
loss: 0.812802 [32000/60000]
loss: 0.890967 [38400/60000]
loss: 0.823702 [44800/60000]
loss: 0.851688 [51200/60000]
loss: 0.821659 [57600/60000]
Test Error:
Accuracy: 0.0%, Avg loss: 0.819076
Epoch 10
-------------
loss: 0.833279 [ 0/60000]
loss: 0.899067 [ 6400/60000]
loss: 0.672593 [12800/60000]
loss: 0.870584 [19200/60000]
loss: 0.774870 [25600/60000]
loss: 0.773825 [32000/60000]
loss: 0.857438 [38400/60000]
loss: 0.795960 [44800/60000]
loss: 0.817822 [51200/60000]
loss: 0.791271 [57600/60000]
Test Error:
Accuracy: 0.0%, Avg loss: 0.787277
Done!
python main.py --device mps 38.84s user 17.26s system 127% cpu 43.980 total
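The GPU run was slower: 43.980 s total versus 30.635 s on the CPU. The accuracy is also reported as 0.0% in every epoch, even though the loss falls much as it does on the CPU.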
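Since the loss still decreases, the training step itself seems to work, so one thing worth checking is whether the accuracy calculation misbehaves on mps. The sketch below moves the predictions and labels to the CPU before comparing them. count_correct_on_cpu is a hypothetical helper that is not in the original script, and I have not confirmed that it makes the 0.0% go away; it is only a way to narrow the problem down.

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:  # torch is only needed for the type hints below
    import torch


def count_correct_on_cpu(pred: torch.Tensor, y: torch.Tensor) -> int:
    """Count correct predictions, doing argmax and comparison on the CPU.

    Hypothetical helper for debugging; pred is the [N, 10] logits tensor
    and y is the [N] label tensor from the test loop.
    """
    # Copy to the CPU first so that no comparison op runs on the mps device.
    pred_labels = pred.detach().cpu().argmax(dim=1)
    labels = y.detach().cpu()
    return int((pred_labels == labels).sum().item())
```

In test(), this would replace the line that updates correct, as in correct += count_correct_on_cpu(pred, y). If the accuracy then matches the CPU run, the problem is probably in one of the ops of the original expression on mps rather than in the training.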
Well, that may be about what to expect from a first preview build. As for speed, a larger model might give a different picture. I don't understand AI at all.
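To put a number on the slowdown, the two summary lines printed by time can be compared directly. This is just the arithmetic on the figures above; the helper name total_seconds is made up, and the wall-clock totals also include Python startup and dataset loading, so the ratio is only a rough indication.

```python
import re

# Summary lines copied from the two runs above.
CPU_LINE = "python main.py --device cpu 28.78s user 9.98s system 126% cpu 30.635 total"
MPS_LINE = "python main.py --device mps 38.84s user 17.26s system 127% cpu 43.980 total"


def total_seconds(time_line: str) -> float:
    """Extract the wall-clock seconds from a zsh-style `time` summary line.

    Only handles a plain seconds value such as "30.635 total",
    not a "minutes:seconds" value.
    """
    match = re.search(r"([\d.]+) total$", time_line)
    if match is None:
        raise ValueError(f"no total time found in: {time_line!r}")
    return float(match.group(1))


cpu_total = total_seconds(CPU_LINE)
mps_total = total_seconds(MPS_LINE)
print(f"mps / cpu wall-clock ratio: {mps_total / cpu_total:.2f}")
```

For this small model, the mps run took a little over 1.4 times as long as the CPU run.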
Some issues that look related have already been filed:
- Tensor will become all zero on MPS after switching to other softwares · Issue #77829 · pytorch/pytorch
- MPS device appears much slower than CPU on M1 Mac Pro · Issue #77799 · pytorch/pytorch
Hopefully things will have improved by the official release.