More than 1 year has passed since last update.

PyTorch can apparently use the Apple Silicon GPU now, so I tried it

PyTorch has announced that its next version (v1.12) will be able to train on the GPU of Apple Silicon Macs. A preview build is already available.

So I tried it right away.

Code

For the dataset and model, I use the FashionMNIST example from the tutorial.

You probably usually set the device to cuda, but for a Mac GPU it is mps (Metal Performance Shaders). For details, see the linked page.

Code:

main.py
import argparse
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor

class MyNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

def get_default_device() -> str:
    if torch.cuda.is_available():
        return "cuda"
    # PyTorch releases before 1.12 have no torch.backends.mps, so check that it exists first
    elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
        return "mps"
    else:
        return "cpu"

def train(dataloader: DataLoader, model: MyNetwork, loss_fn: nn.CrossEntropyLoss, optimizer: torch.optim.SGD, device: str) -> None:
    size = len(dataloader.dataset)
    model.train()
    for batch, (X, y) in enumerate(dataloader):
        # Move the batch to the same device as the model (cpu / cuda / mps)
        X, y = X.to(device), y.to(device)

        pred = model(X)
        loss = loss_fn(pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")

def test(dataloader: DataLoader, model: MyNetwork, loss_fn: nn.CrossEntropyLoss, device: str) -> None:
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    model.eval()
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            test_loss += loss_fn(pred, y).item()
            # Count predictions whose highest-scoring class matches the label
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(f"Test Error:\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}\n")

def main() -> None:
    parser = argparse.ArgumentParser(description="FashionMNIST")
    parser.add_argument("--device", default=get_default_device(), help="Device to use")
    args = parser.parse_args()
    device = args.device

    training_data = datasets.FashionMNIST(root="data", train=True, download=True, transform=ToTensor())
    test_data = datasets.FashionMNIST(root="data", train=False, download=True, transform=ToTensor())

    batch_size = 64

    train_dataloader = DataLoader(training_data, batch_size=batch_size)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    for X, y in test_dataloader:
        print(f"Shape of X [N, C, H, W]: {X.shape}")
        print(f"Shape of y: {y.shape} {y.dtype}")
        break

    print(f"Using {device} device")

    model = MyNetwork().to(device)
    print(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

    epochs = 10
    for t in range(epochs):
        print(f"Epoch {t+1}\n-------------")
        train(train_dataloader, model, loss_fn, optimizer, device=device)
        test(test_dataloader, model, loss_fn, device=device)
    print("Done!")

if __name__ == "__main__":
    main()
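The progress lines in the logs below come from the "if batch % 100 == 0" check in train(). The following sketch works through that arithmetic in plain Python, with no PyTorch needed. The constant names are my own, not from main.py.

There are 60,000 training images and batch_size = 64, which gives 938 batches. Batches 0, 100, ..., 900 are logged. The last batch is only partly filled, because DataLoader keeps incomplete batches by default (drop_last=False).

```python
import math

DATASET_SIZE = 60000  # size of the FashionMNIST training split ("/60000" in the logs)
BATCH_SIZE = 64       # batch_size in main()
LOG_EVERY = 100       # train() prints when batch % 100 == 0

num_batches = math.ceil(DATASET_SIZE / BATCH_SIZE)
# train() prints current = batch * len(X); every logged batch is a full batch of 64
logged_offsets = [b * BATCH_SIZE for b in range(num_batches) if b % LOG_EVERY == 0]
# The final batch holds whatever is left over
last_batch_size = DATASET_SIZE - (num_batches - 1) * BATCH_SIZE

print(num_batches)      # 938
print(logged_offsets)   # [0, 6400, 12800, 19200, 25600, 32000, 38400, 44800, 51200, 57600]
print(last_batch_size)  # 32
```

This is why each epoch below shows exactly ten loss lines, ending at 57600/60000.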

Results

I tried this on a Mac mini (M1, 2020, 8-core GPU). PyTorch is installed with something like

pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu

The URL says cpu, but don't worry about that. If torch.backends.mps.is_available() returns True, the install worked.
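For a slightly more thorough check of the install, a small script like the one below could be used. It assumes torch.backends.mps provides is_built(), which reports whether this PyTorch binary was compiled with MPS support. It also assumes is_available(), which reports whether MPS can actually be used on this machine. mps_status is a hypothetical helper name. The getattr guard covers PyTorch builds that have no mps backend at all.

```python
from typing import Dict, Optional, Union


def mps_status() -> Dict[str, Union[Optional[str], bool]]:
    """Report the PyTorch version and whether the MPS backend is built and usable."""
    try:
        import torch
    except ImportError:
        # PyTorch is not installed at all
        return {"torch_version": None, "mps_built": False, "mps_available": False}

    # Releases before 1.12 have no torch.backends.mps
    mps = getattr(torch.backends, "mps", None)
    built = mps is not None and mps.is_built()
    available = mps is not None and mps.is_available()
    return {
        "torch_version": str(torch.__version__),
        "mps_built": bool(built),
        "mps_available": bool(available),
    }


if __name__ == "__main__":
    print(mps_status())
```

If mps_built is True but mps_available is False, the binary itself supports MPS and the problem is more likely the OS or the machine.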

First, the CPU.

% time python main.py --device cpu 
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
Using cpu device
MyNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
Epoch 1
-------------
loss: 2.293083  [    0/60000]
loss: 2.283966  [ 6400/60000]
loss: 2.265248  [12800/60000]
loss: 2.260847  [19200/60000]
loss: 2.241223  [25600/60000]
loss: 2.214097  [32000/60000]
loss: 2.221246  [38400/60000]
loss: 2.191444  [44800/60000]
loss: 2.184505  [51200/60000]
loss: 2.152950  [57600/60000]
Test Error:
 Accuracy: 52.5%, Avg loss: 2.147460

Epoch 2
-------------
loss: 2.154118  [    0/60000]
loss: 2.147166  [ 6400/60000]
loss: 2.085644  [12800/60000]
loss: 2.106354  [19200/60000]
loss: 2.053174  [25600/60000]
loss: 1.997061  [32000/60000]
loss: 2.023799  [38400/60000]
loss: 1.948802  [44800/60000]
loss: 1.954430  [51200/60000]
loss: 1.882008  [57600/60000]
Test Error:
 Accuracy: 58.4%, Avg loss: 1.874473

Epoch 3
-------------
loss: 1.906654  [    0/60000]
loss: 1.880482  [ 6400/60000]
loss: 1.755886  [12800/60000]
loss: 1.799277  [19200/60000]
loss: 1.690940  [25600/60000]
loss: 1.644767  [32000/60000]
loss: 1.664432  [38400/60000]
loss: 1.566495  [44800/60000]
loss: 1.593532  [51200/60000]
loss: 1.487655  [57600/60000]
Test Error:
 Accuracy: 60.2%, Avg loss: 1.500291

Epoch 4
-------------
loss: 1.567383  [    0/60000]
loss: 1.536163  [ 6400/60000]
loss: 1.379712  [12800/60000]
loss: 1.453804  [19200/60000]
loss: 1.334036  [25600/60000]
loss: 1.333584  [32000/60000]
loss: 1.346803  [38400/60000]
loss: 1.272417  [44800/60000]
loss: 1.310961  [51200/60000]
loss: 1.213349  [57600/60000]
Test Error:
 Accuracy: 62.9%, Avg loss: 1.235704

Epoch 5
-------------
loss: 1.310030  [    0/60000]
loss: 1.295836  [ 6400/60000]
loss: 1.127985  [12800/60000]
loss: 1.235680  [19200/60000]
loss: 1.106518  [25600/60000]
loss: 1.139977  [32000/60000]
loss: 1.159387  [38400/60000]
loss: 1.099196  [44800/60000]
loss: 1.141269  [51200/60000]
loss: 1.060075  [57600/60000]
Test Error:
 Accuracy: 64.4%, Avg loss: 1.076839

Epoch 6
-------------
loss: 1.143180  [    0/60000]
loss: 1.148399  [ 6400/60000]
loss: 0.967611  [12800/60000]
loss: 1.103710  [19200/60000]
loss: 0.970504  [25600/60000]
loss: 1.013126  [32000/60000]
loss: 1.046940  [38400/60000]
loss: 0.993126  [44800/60000]
loss: 1.033323  [51200/60000]
loss: 0.965677  [57600/60000]
Test Error:
 Accuracy: 65.8%, Avg loss: 0.975701

Epoch 7
-------------
loss: 1.029155  [    0/60000]
loss: 1.052933  [ 6400/60000]
loss: 0.859105  [12800/60000]
loss: 1.016738  [19200/60000]
loss: 0.885197  [25600/60000]
loss: 0.923422  [32000/60000]
loss: 0.973342  [38400/60000]
loss: 0.925087  [44800/60000]
loss: 0.958950  [51200/60000]
loss: 0.901543  [57600/60000]
Test Error:
 Accuracy: 67.2%, Avg loss: 0.906411

Epoch 8
-------------
loss: 0.945191  [    0/60000]
loss: 0.985683  [ 6400/60000]
loss: 0.780831  [12800/60000]
loss: 0.955175  [19200/60000]
loss: 0.827421  [25600/60000]
loss: 0.857079  [32000/60000]
loss: 0.921167  [38400/60000]
loss: 0.879375  [44800/60000]
loss: 0.904946  [51200/60000]
loss: 0.854444  [57600/60000]
Test Error:
 Accuracy: 68.1%, Avg loss: 0.855827

Epoch 9
-------------
loss: 0.879951  [    0/60000]
loss: 0.934363  [ 6400/60000]
loss: 0.721371  [12800/60000]
loss: 0.909010  [19200/60000]
loss: 0.785734  [25600/60000]
loss: 0.806195  [32000/60000]
loss: 0.881165  [38400/60000]
loss: 0.847154  [44800/60000]
loss: 0.864123  [51200/60000]
loss: 0.817849  [57600/60000]
Test Error:
 Accuracy: 69.4%, Avg loss: 0.817021

Epoch 10
-------------
loss: 0.827321  [    0/60000]
loss: 0.892991  [ 6400/60000]
loss: 0.674449  [12800/60000]
loss: 0.873180  [19200/60000]
loss: 0.754036  [25600/60000]
loss: 0.766333  [32000/60000]
loss: 0.848583  [38400/60000]
loss: 0.823278  [44800/60000]
loss: 0.831912  [51200/60000]
loss: 0.787997  [57600/60000]
Test Error:
 Accuracy: 70.8%, Avg loss: 0.785940

Done!
python main.py --device cpu  28.78s user 9.98s system 126% cpu 30.635 total

Nothing unusual here.

Next, the GPU.

% time python main.py --device mps
Shape of X [N, C, H, W]: torch.Size([64, 1, 28, 28])
Shape of y: torch.Size([64]) torch.int64
Using mps device
MyNetwork(
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=784, out_features=512, bias=True)
    (1): ReLU()
    (2): Linear(in_features=512, out_features=512, bias=True)
    (3): ReLU()
    (4): Linear(in_features=512, out_features=10, bias=True)
  )
)
Epoch 1
-------------
loss: 2.308286  [    0/60000]
loss: 2.300420  [ 6400/60000]
loss: 2.284562  [12800/60000]
loss: 2.275385  [19200/60000]
loss: 2.261804  [25600/60000]
loss: 2.234581  [32000/60000]
loss: 2.240901  [38400/60000]
loss: 2.208639  [44800/60000]
loss: 2.214520  [51200/60000]
loss: 2.177371  [57600/60000]
Test Error:
 Accuracy: 0.0%, Avg loss: 2.178863

Epoch 2
-------------
loss: 2.192659  [    0/60000]
loss: 2.186912  [ 6400/60000]
loss: 2.135712  [12800/60000]
loss: 2.147661  [19200/60000]
loss: 2.102172  [25600/60000]
loss: 2.052881  [32000/60000]
loss: 2.078633  [38400/60000]
loss: 2.008227  [44800/60000]
loss: 2.017821  [51200/60000]
loss: 1.943297  [57600/60000]
Test Error:
 Accuracy: 0.0%, Avg loss: 1.944851

Epoch 3
-------------
loss: 1.981287  [    0/60000]
loss: 1.952881  [ 6400/60000]
loss: 1.845932  [12800/60000]
loss: 1.874648  [19200/60000]
loss: 1.764092  [25600/60000]
loss: 1.722995  [32000/60000]
loss: 1.738531  [38400/60000]
loss: 1.639645  [44800/60000]
loss: 1.657003  [51200/60000]
loss: 1.547692  [57600/60000]
Test Error:
 Accuracy: 0.0%, Avg loss: 1.566313

Epoch 4
-------------
loss: 1.635859  [    0/60000]
loss: 1.593795  [ 6400/60000]
loss: 1.450184  [12800/60000]
loss: 1.509478  [19200/60000]
loss: 1.384481  [25600/60000]
loss: 1.386362  [32000/60000]
loss: 1.399653  [38400/60000]
loss: 1.316242  [44800/60000]
loss: 1.344521  [51200/60000]
loss: 1.250790  [57600/60000]
Test Error:
 Accuracy: 0.0%, Avg loss: 1.273837

Epoch 5
-------------
loss: 1.351406  [    0/60000]
loss: 1.328287  [ 6400/60000]
loss: 1.167860  [12800/60000]
loss: 1.267555  [19200/60000]
loss: 1.138780  [25600/60000]
loss: 1.168856  [32000/60000]
loss: 1.193459  [38400/60000]
loss: 1.118520  [44800/60000]
loss: 1.153710  [51200/60000]
loss: 1.081497  [57600/60000]
Test Error:
 Accuracy: 0.0%, Avg loss: 1.097177

Epoch 6
-------------
loss: 1.167233  [    0/60000]
loss: 1.165472  [ 6400/60000]
loss: 0.988221  [12800/60000]
loss: 1.121341  [19200/60000]
loss: 0.993114  [25600/60000]
loss: 1.028384  [32000/60000]
loss: 1.068532  [38400/60000]
loss: 0.995041  [44800/60000]
loss: 1.033682  [51200/60000]
loss: 0.977396  [57600/60000]
Test Error:
 Accuracy: 0.0%, Avg loss: 0.986301

Epoch 7
-------------
loss: 1.043652  [    0/60000]
loss: 1.062958  [ 6400/60000]
loss: 0.868568  [12800/60000]
loss: 1.026042  [19200/60000]
loss: 0.904372  [25600/60000]
loss: 0.932638  [32000/60000]
loss: 0.988347  [38400/60000]
loss: 0.915102  [44800/60000]
loss: 0.952812  [51200/60000]
loss: 0.908595  [57600/60000]
Test Error:
 Accuracy: 0.0%, Avg loss: 0.912048

Epoch 8
-------------
loss: 0.954836  [    0/60000]
loss: 0.992852  [ 6400/60000]
loss: 0.783999  [12800/60000]
loss: 0.959045  [19200/60000]
loss: 0.846602  [25600/60000]
loss: 0.863988  [32000/60000]
loss: 0.932708  [38400/60000]
loss: 0.861322  [44800/60000]
loss: 0.895089  [51200/60000]
loss: 0.859447  [57600/60000]
Test Error:
 Accuracy: 0.0%, Avg loss: 0.859100

Epoch 9
-------------
loss: 0.887372  [    0/60000]
loss: 0.940760  [ 6400/60000]
loss: 0.721235  [12800/60000]
loss: 0.909224  [19200/60000]
loss: 0.805790  [25600/60000]
loss: 0.812802  [32000/60000]
loss: 0.890967  [38400/60000]
loss: 0.823702  [44800/60000]
loss: 0.851688  [51200/60000]
loss: 0.821659  [57600/60000]
Test Error:
 Accuracy: 0.0%, Avg loss: 0.819076

Epoch 10
-------------
loss: 0.833279  [    0/60000]
loss: 0.899067  [ 6400/60000]
loss: 0.672593  [12800/60000]
loss: 0.870584  [19200/60000]
loss: 0.774870  [25600/60000]
loss: 0.773825  [32000/60000]
loss: 0.857438  [38400/60000]
loss: 0.795960  [44800/60000]
loss: 0.817822  [51200/60000]
loss: 0.791271  [57600/60000]
Test Error:
 Accuracy: 0.0%, Avg loss: 0.787277

Done!
python main.py --device mps  38.84s user 17.26s system 127% cpu 43.980 total
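To put a number on the difference between the two runs, the sketch below parses the two time lines above and divides their wall-clock totals. It assumes zsh's default time format, since the prompt is %. parse_zsh_time is a hypothetical helper for this comparison and is not part of main.py.

```python
import re
from typing import Dict

# Matches zsh's default time report, e.g. "28.78s user 9.98s system 126% cpu 30.635 total"
TIME_RE = re.compile(
    r"(?P<user>[\d.]+)s user (?P<system>[\d.]+)s system (?P<cpu>\d+)% cpu (?P<total>[\d.]+) total"
)


def parse_zsh_time(line: str) -> Dict[str, float]:
    """Extract user/system seconds, CPU percentage and wall-clock total from a zsh time line."""
    match = TIME_RE.search(line)
    if match is None:
        raise ValueError(f"not a zsh time line: {line!r}")
    return {key: float(value) for key, value in match.groupdict().items()}


cpu_run = parse_zsh_time("python main.py --device cpu  28.78s user 9.98s system 126% cpu 30.635 total")
mps_run = parse_zsh_time("python main.py --device mps  38.84s user 17.26s system 127% cpu 43.980 total")

slowdown = mps_run["total"] / cpu_run["total"]
print(f"mps wall-clock time: {slowdown:.2f}x the cpu run")  # mps wall-clock time: 1.44x the cpu run
```

This measures the whole script, including startup and dataset loading, so it is not a pure comparison of the training loop.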

The GPU run was slower: 43.980 s total versus 30.635 s on the CPU. Accuracy is also 0.0% in every epoch.
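On mps, the loss falls at almost the same rate as on the CPU: the epoch 10 test Avg loss is 0.787277 versus 0.785940. So training itself seems to work. My guess is that only the accuracy calculation, (pred.argmax(1) == y), goes wrong on mps.

One way to cross-check, sketched under that assumption, is to recompute accuracy for a batch in plain Python from values copied off the device. reference_accuracy is a hypothetical helper, not part of main.py.

```python
from typing import Sequence


def reference_accuracy(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Fraction of rows whose highest-scoring class index equals the label.

    Pure Python, so the result does not depend on any GPU kernel.
    """
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    correct = 0
    for row, label in zip(logits, labels):
        # Index of the largest score (the first one wins on ties)
        predicted = max(range(len(row)), key=lambda i: row[i])
        correct += int(predicted == label)
    return correct / len(labels)


# Toy batch: 3 samples, 3 classes; 2 of the 3 predictions match
print(reference_accuracy([[0.1, 2.0, 0.3], [1.5, 0.2, 0.1], [0.0, 0.1, 0.4]], [1, 0, 1]))
```

Inside test(), you could compute reference_accuracy(pred.cpu().tolist(), y.cpu().tolist()) for one batch. Then compare it with what (pred.argmax(1) == y) gives on mps.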

Well, for the very first preview build, maybe this is to be expected. The speed might look different with a larger model. I don't know anything about AI, lol.
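For a sense of scale, the model here is tiny. Counting the parameters of the three nn.Linear layers by hand (in × out weights plus one bias per output) gives about 670k parameters, roughly 2.7 MB in float32. With a network this small and batch_size = 64, per-batch overhead may dominate rather than raw compute. That is my guess, not something I measured, but it fits the idea that a larger model could give a different result.

```python
# (in_features, out_features) of the three nn.Linear layers in MyNetwork
LAYERS = [(28 * 28, 512), (512, 512), (512, 10)]


def linear_param_count(in_features: int, out_features: int, bias: bool = True) -> int:
    """Parameters in one nn.Linear: a weight matrix plus an optional bias vector."""
    return in_features * out_features + (out_features if bias else 0)


total_params = sum(linear_param_count(i, o) for i, o in LAYERS)
float32_bytes = total_params * 4  # 4 bytes per float32 parameter

print(total_params)   # 669706
print(float32_bytes)  # 2678824
```

With PyTorch available, sum(p.numel() for p in model.parameters()) on the real model should give the same count, since Flatten and ReLU have no parameters.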

A few issues that look related have already been opened:

Hopefully things will be better by the stable release.
