0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Apple siliconマシンでPyTorchのdeviceが一致しない場合の対処法

Last updated at Posted at 2025-01-06

Nvidia GPU のマシンでは動いていたコードがApple siliconのマシンでは Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same と怒られた.
単純だけど意外と気づくのに時間がかかったので記録しておく.

忙しい人へ(結論)

torchinfo.summary をApple siliconで使う場合は引数でdeviceを指定する.

from torchinfo import summary
summary(model, input_size=(1, 1), device="mps")

本編

環境

MacBook Air M3
macOS Sonoma 14.5
Python : 3.11.11
torch : 2.7.0.dev20250105
torchinfo : 1.8.0

問題のコード

import torch
import torch.nn as nn
from torchinfo import summary

if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print( "Device      :", device)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.fc1 = nn.Linear(1, 1, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        return x
    
model = Model().to(device)
def check_device(model):
    for name, param in model.named_parameters():
        print(f"{name} is on {param.device}")
        
check_device(model)
summary(model, input_size=(1, 1)) # <- deviceを指定していない

def eval():
    check_device(model)
    model.eval()
    x = torch.randn(1, 1).to(device)
    y = model(x)

def train():
    check_device(model)
    model.train()
    x = torch.randn(1, 1).to(device)
    y = model(x)
    
train()
eval()

このコードを実行すると,
Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the sameとか,
Tensor for argument weight is on cpu but expected on mps とかでエラーになる.

各段階でのdeviceを確認すると,一回目のcheck_device(model)ではmpsになるが,二回目ではcpuになる.つまり悪さをしているのはsummaryということになる.

ではtorchinfo.summaryは何をしているのか?

原因 (torchinfo.summary の挙動)

torchinfosummary関数は以下のように書かれている.
(関連部分のみ抜粋)

torchinfo.py
def summary(
    model: nn.Module,
    input_size: INPUT_SIZE_TYPE | None = None,
    input_data: INPUT_DATA_TYPE | None = None,
    ...
    device: torch.device | str | None = None,
    ...
) -> ModelStatistics:
    ...
    if device is None:
        device = get_device(model, input_data)
    elif isinstance(device, str):
        device = torch.device(device)

このように引数でdeviceを指定しない場合, get_device関数で自動的にdevice情報を取得している.

では,get_device関数はどうなっているのか?

torchinfo.py
def get_device(
    model: nn.Module, input_data: INPUT_DATA_TYPE | None
) -> torch.device | None:
    if input_data is None:
        try:
            model_parameter = next(model.parameters())
        except StopIteration:
            model_parameter = None

        if model_parameter is not None and model_parameter.is_cuda:
            return model_parameter.device
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return None

さて,お気づきの通りmpsの記述がどこにもない.
get_device関数はcudaの確認はしてくれるが,mpsの場合devicecpuにされてしまう.

解決策

よって解決策は3通り.

1. torchinfo.summary を使わない.

# summary(model, input_size=(1, 1)) # <- を削除

2. 引数でdeviceを指定する.

summary(model, input_size=(1, 1), device=device)
# 今回のコードでは summary(model, input_size=(1, 1), device="mps") と同じ意味

3. torchinfo.get_devicempsに対応させる.

get_device関数を以下のように改造.

torchinfo
def get_device(
    model: nn.Module, input_data: INPUT_DATA_TYPE | None
) -> torch.device | None:
    if input_data is None:
        try:
            model_parameter = next(model.parameters())
        except StopIteration:
            model_parameter = None

        if model_parameter is not None and model_parameter.is_cuda:
            return model_parameter.device
        # return torch.device("cuda" if torch.cuda.is_available() else "cpu")
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")
        return device
    return None

以上,Apple siliconマシンでdeviceが一致しない時の対処法でした.

0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?