Nvidia GPU のマシンでは動いていたコードがApple siliconのマシンでは Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same
と怒られた.
単純だけど意外と気づくのに時間がかかったので記録しておく.
忙しい人へ(結論)
torchinfo.summary
をApple siliconで使う場合は引数でdevice
を指定する.
from torchinfo import summary
summary(model, input_size=(1, 1), device="mps")
本編
環境
MacBook Air M3
macOS Sonoma 14.5
Python : 3.11.11
torch : 2.7.0.dev20250105
torchinfo : 1.8.0
問題のコード
import torch
import torch.nn as nn
from torchinfo import summary
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
os.environ["KMP_DUPLICATE_LIB_OK"] = "True"
device = torch.device("mps")
else:
device = torch.device("cpu")
print( "Device :", device)
class Model(nn.Module):
def __init__(self):
super(Model, self).__init__()
self.fc1 = nn.Linear(1, 1, bias=False)
def forward(self, x):
x = self.fc1(x)
return x
model = Model().to(device)
def check_device(model):
for name, param in model.named_parameters():
print(f"{name} is on {param.device}")
check_device(model)
summary(model, input_size=(1, 1)) # <- deviceを指定していない
def eval():
check_device(model)
model.eval()
x = torch.randn(1, 1).to(device)
y = model(x)
def train():
check_device(model)
model.train()
x = torch.randn(1, 1).to(device)
y = model(x)
train()
eval()
このコードを実行すると,
Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same
とか,
Tensor for argument weight is on cpu but expected on mps
とかでエラーになる.
各段階でのdevice
を確認すると,一回目のcheck_device(model)
ではmps
になるが,二回目ではcpu
になる.つまり悪さをしているのはsummary
ということになる.
ではtorchinfo.summary
は何をしているのか?
原因 (torchinfo.summary
の挙動)
torchinfo
のsummary
関数は以下のように書かれている.
(関連部分のみ抜粋)
def summary(
model: nn.Module,
input_size: INPUT_SIZE_TYPE | None = None,
input_data: INPUT_DATA_TYPE | None = None,
...
device: torch.device | str | None = None,
...
) -> ModelStatistics:
...
if device is None:
device = get_device(model, input_data)
elif isinstance(device, str):
device = torch.device(device)
このように引数でdevice
を指定しない場合, get_device
関数で自動的にdevice
情報を取得している.
では,get_device
関数はどうなっているのか?
def get_device(
model: nn.Module, input_data: INPUT_DATA_TYPE | None
) -> torch.device | None:
if input_data is None:
try:
model_parameter = next(model.parameters())
except StopIteration:
model_parameter = None
if model_parameter is not None and model_parameter.is_cuda:
return model_parameter.device
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return None
さて,お気づきの通りmps
の記述がどこにもない.
get_device
関数はcuda
の確認はしてくれるが,mps
の場合device
がcpu
にされてしまう.
解決策
よって解決策は3通り.
1. torchinfo.summary
を使わない.
# summary(model, input_size=(1, 1)) # <- を削除
2. 引数でdevice
を指定する.
summary(model, input_size=(1, 1), device=device)
# 今回のコードでは summary(model, input_size=(1, 1), device="mps") と同じ意味
3. torchinfo.get_device
をmps
に対応させる.
get_device
関数を以下のように改造.
def get_device(
model: nn.Module, input_data: INPUT_DATA_TYPE | None
) -> torch.device | None:
if input_data is None:
try:
model_parameter = next(model.parameters())
except StopIteration:
model_parameter = None
if model_parameter is not None and model_parameter.is_cuda:
return model_parameter.device
# return torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
device = torch.device("cuda")
elif torch.backends.mps.is_available():
device = torch.device("mps")
else:
device = torch.device("cpu")
return device
return None
以上,Apple siliconマシンでdeviceが一致しない時の対処法でした.