※この記事は東大電気電子情報系EEICのアドベントカレンダーとして寄稿しています
AppleがMLXとかいうMLライブラリを出しました
Pytorch Likeな書き心地で, Apple製品に最適化されているライブラリらしいです。
Pytorchにどれぐらい似ているかを確認するために
実際にMLXでのTransformerを実装例を見てみましょう。
Pytorchに書き慣れている人は本当ににていることがわかるかと思います。
Pytorchと命名規則が同じならimport fileを変えるだけなので互換性のある進化をして欲しいなと思います。
# Copyright © 2023 Apple Inc.
import math
import time
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
import datasets
class TransformerLM(nn.Module):
def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.transformer = nn.TransformerEncoder(num_layers, dims, num_heads)
self.out_proj = nn.Linear(dims, vocab_size)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
x = self.embedding(x)
x = self.transformer(x, mask)
return self.out_proj(x)
def loss(self, x, y, reduce=True):
logits = self(x)
losses = nn.losses.cross_entropy(logits, y)
mx.simplify(losses)
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
Apple製品ってGPUついてたっけ
実はApple SiliconのMacにはmpsというGPUがあります。参考記事→PyTorchをM1 MacBook のGPU(MPS)で動かす.実行時間の検証もしたよ
MacUserでGPUを買うお金もない時代は自分もmpsにはよくお世話にはなりました。
今後MLXが既存のTensorFlow, JAX, Pytorchとの差別化を図るならmpsに最適化されたライブラリになっていくんじゃないかなと思います。
他にMLXをわざわざ使う利点ってあるの?
- あります
どうやらMLXは最近のML論文の成果がnnライブラリ内に実装済みだそうです。例えば, 回転位置エンコーディング(RoPE)などですね。これはPytorchではScratchで実装するしかないみたいですが, MLXには既にあります。
個人的な意見
Keras 3に入ってくれないとフレームワークが多すぎて困る
まとめ
MLXというAppleが制作したMLライブラリの紹介をしました。Macが自社製品のハードウェアに特化して動く高速かつ軽量なソフトウェアを制作したかったんじゃないかなと考えています。
オマケ(本編)
こちらをご覧ください
なんとMLXを利用して最近のMLモデルの実装例が全部あります。マジで激アツです。
しかもPytorchのParameterをMLXに変換までしてくれるコードもありました。マジで激アツです。
しかも他ライブラリとの比較のために, Transformerの実装例ではTF/JAX/Pytorch/MLXの4つの実装例もあります。
WhisperはPytorchでの実装例もありました。
これでこの記事は以上です。
12/7 21:54追記 「ちょっと待て」
MLXでのTransformerの実装をもう一回見てみましょう。
# Copyright © 2023 Apple Inc.
import math
import time
import numpy as np
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from mlx.utils import tree_flatten
import datasets
class TransformerLM(nn.Module):
def __init__(self, vocab_size: int, num_layers: int, dims: int, num_heads: int):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.transformer = nn.TransformerEncoder(num_layers, dims, num_heads)
self.out_proj = nn.Linear(dims, vocab_size)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
x = self.embedding(x)
x = self.transformer(x, mask)
return self.out_proj(x)
def loss(self, x, y, reduce=True):
logits = self(x)
losses = nn.losses.cross_entropy(logits, y)
mx.simplify(losses)
return mx.mean(losses) if reduce else mx.mean(losses, axis=(-1, -2))
Pytorchでforwardメソッドに当たる部分が__call__じゃないか
MLXで書いたコードはそのままPytorchに移せないじゃないか