本記事は、京都大学人工知能研究会KaiRAのAdvent Calendar 20日目の記事です。
やったこと
CIFAR-10のtrainデータでResNetを学習、CIFAR-10のtestデータでvalidationを行いました。
(本当はtrainデータをsplitしてvalidationデータを作った方が良いのでしょうが、ここでは簡易的にtestデータを使ってしまいました。)
本当はJAX/Flaxのメリットを生かした最適化(vmap, pmapなど)をバリバリするべきなのでしょうが、今回は行いませんでした。
PyTorchと違うと感じたこと
- 初期化の時などに、乱数のgeneratorを作ってしっかり渡さないといけない。(dropoutを使う場合はdropout用の乱数generatorも渡さないといけない。)
@jax.jit
def initialize(params_rng):
    """Create the initial variable collections of ``config.model``.

    An all-ones float32 batch of shape (1, image_size, image_size, 3) is fed
    to ``init`` with ``train=False``; ``params_rng`` is passed as the
    'params' RNG stream. Returns whatever ``config.model.init`` returns.
    """
    side = config.image_size
    dummy_batch = jnp.ones((1, side, side, 3), jnp.float32)
    return config.model.init({'params': params_rng}, dummy_batch, train=False)


variables = initialize(config.seed_rng)
- モデルのweightはmodelとは別に管理する。また、モデルの個々のweightは変更することはできず、weight全体を代入して変更を行う。(文章では言いにくいです…)そのため、TrainStateを使ってパラメータを管理して学習を行う。
class TrainState(train_state.TrainState):
    """``train_state.TrainState`` extended with a ``batch_stats`` field."""
    # Stored next to params; populated from variables['batch_stats'] below
    batch_stats: Any


state = TrainState.create(
    apply_fn=config.model.apply,
    tx=tx,
    params=variables['params'],
    batch_stats=variables['batch_stats'],
)
@jax.jit
def train_step(state, batch):
    """Run one optimisation step and return ``(new_state, metrics)``.

    ``batch`` is indexed as ``batch[0]`` (images) and ``batch[1]`` (labels).
    """
    images, labels = batch[0], batch[1]

    def loss_fn(params):
        # mutable='batch_stats' makes apply also return the updated collection
        logits, updated_collections = state.apply_fn(
            {'params': params, 'batch_stats': state.batch_stats},
            images,
            train=True,
            mutable='batch_stats',
        )
        return cross_entropy_loss(logits, labels), (updated_collections, logits)

    # has_aux=True: loss_fn returns (loss, aux); the loss value itself is not
    # needed here because compute_metrics recomputes it from the logits.
    (_, (updated_collections, logits)), grads = jax.value_and_grad(
        loss_fn, has_aux=True
    )(state.params)
    metrics = compute_metrics(logits, labels)
    new_state = state.apply_gradients(
        grads=grads,
        batch_stats=updated_collections['batch_stats'],
    )
    return new_state, metrics
@jax.jit
def eval_step(state, batch):
    """Evaluate one batch with ``train=False`` and return its metrics dict.

    ``batch`` is indexed as ``batch[0]`` (images) and ``batch[1]`` (labels).
    """
    images, labels = batch[0], batch[1]
    logits = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        images,
        train=False,
        mutable=False,
    )
    return compute_metrics(logits, labels)
def train(state, train_loader, val_loader):
    """Train for ``config.epochs`` epochs, validating after each one.

    Returns the ``defaultdict(list)`` history filled in by ``log_metrics``.
    The updated ``state`` stays local to this function and is not returned.
    """
    history = defaultdict(list)
    for epoch in range(1, config.epochs + 1):
        print(f'{epoch}/{config.epochs}')

        print('train')
        train_metrics = []
        for batch in tqdm(train_loader):
            state, step_metrics = train_step(state, batch)
            train_metrics.append(step_metrics)
        history = log_metrics(train_metrics, history, 'train')

        print('validation')
        val_metrics = [eval_step(state, batch) for batch in tqdm(val_loader)]
        history = log_metrics(val_metrics, history, 'val')
    return history
- Datasetを読み込む関数は用意されていない。なので、PyTorchかTensorFlowか、はたまた別のライブラリのものを使うことになるだろう。(もしかしたらきちんと探せば何かしらあるかもしれないですが…)PyTorchのDataLoaderを使ってしまうと、その中の関数でjax.jitが使えなくなってしまう??
感想
上に挙げたような違いはありますが、PyTorchと似たような感覚で書けました。(逆に、上のような違いがあることを知るのに少しだけ苦労しました。)
乱数の話を書きましたが、個人的にはPyTorchの乱数のseedは忘れてないかの不安感ときちんと設定できているのかどうかの不安感があるので、Flaxの書き方の方が嬉しい感じがします。
PyTorchとは違い、勾配を求めるコードは自分で書いて、勾配を使ってパラメータの更新をするのはライブラリに投げる形式なので、PyTorchみたいにloss.backward()とoptimizer.step()でパラメータを更新するよりも、自分で書いていて納得感がありました。(当然ながらPyTorchの方が記述量的には楽ですが。)
個人的には「(コードが分かりにくくならない範囲では)速いは基本的に正義」だと思っているので、是非もっと普及して若干マイナーな分野における事前学習済みモデルなどもJAX/Flaxで準備されるようになってもらえたら、もっと自信をもってバリバリ使えるのにな~、と思っています。
コードの書き方自体はかなり好みではあるので、みなさんどんどん使っていただいてもっと普及していったら嬉しいです。
今回書いたコード
一応以下にコードを貼っておきます。
testのaccuracyは83%でした。
正直言ってチューニングをしようと思ったらまだまだできることはたくさんあると思いますが、今回はここまでということで、チューニングの続きはkaggleで頑張りたいと思います。
# Import
import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state, common_utils
import optax
import torch
import torchvision
import torchvision.transforms as transforms
import albumentations as albu
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
from collections import defaultdict
from typing import Any, Callable, Sequence, Tuple
from functools import partial
from tqdm.auto import tqdm
# Model
# Loose type alias used below to annotate fields that hold a module
# constructor (e.g. `conv`, `norm`, `block_cls`).
ModuleDef = Any
class ResNetBlock(nn.Module):
    """ResNet block: two 3x3 convolutions plus a shortcut connection.

    ``strides`` is applied in the first convolution and, when a projection
    is needed, in the 1x1 shortcut convolution.
    """
    filters: int
    conv: ModuleDef
    norm: ModuleDef
    act: Callable
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        residual = x
        y = self.conv(self.filters, (3, 3), self.strides)(x)
        y = self.norm()(y)
        y = self.act(y)
        y = self.conv(self.filters, (3, 3))(y)
        # The scale of the last norm is initialised to zero.
        y = self.norm(scale_init=nn.initializers.zeros_init())(y)

        # Project the shortcut only when its shape no longer matches the
        # main branch (different stride and/or channel count).
        if residual.shape != y.shape:
            residual = self.conv(
                self.filters, (1, 1), self.strides, name='conv_proj'
            )(residual)
            # Was 'name_proj' (typo); 'norm_proj' matches the name used by
            # BottleneckResNetBlock for the same layer.
            residual = self.norm(name='norm_proj')(residual)

        return self.act(residual + y)
class BottleneckResNetBlock(nn.Module):
    """Bottleneck ResNet block: 1x1 -> 3x3 -> 1x1 convolutions plus a shortcut.

    The last 1x1 convolution outputs ``filters * 4`` channels. ``strides`` is
    applied in the 3x3 convolution and in the shortcut projection.
    """
    filters: int
    conv: ModuleDef
    norm: ModuleDef
    act: Callable
    strides: Tuple[int, int] = (1, 1)

    @nn.compact
    def __call__(self, x):
        out_channels = self.filters * 4
        shortcut = x

        out = self.conv(self.filters, (1, 1))(x)
        out = self.act(self.norm()(out))
        out = self.conv(self.filters, (3, 3), self.strides)(out)
        out = self.act(self.norm()(out))
        out = self.conv(out_channels, (1, 1))(out)
        # The scale of the last norm is initialised to zero.
        out = self.norm(scale_init=nn.initializers.zeros_init())(out)

        # Project the shortcut only when shapes differ from the main branch.
        if shortcut.shape != out.shape:
            shortcut = self.conv(
                out_channels, (1, 1), self.strides, name='conv_proj'
            )(shortcut)
            shortcut = self.norm(name='norm_proj')(shortcut)

        return self.act(shortcut + out)
class ResNet(nn.Module):
    """ResNetV1.5.

    A 7x7 stride-2 stem convolution, batch norm, ReLU and a 3x3 stride-2 max
    pool are followed by ``len(stage_sizes)`` stages of ``block_cls`` blocks,
    a mean over axes (1, 2) and a dense layer with ``num_classes`` outputs.
    """
    stage_sizes: Sequence[int]
    block_cls: ModuleDef
    num_classes: int
    num_filters: int = 64
    dtype: Any = jnp.float32
    act: Callable = nn.relu
    conv: ModuleDef = nn.Conv

    @nn.compact
    def __call__(self, x, train: bool = True):
        conv = partial(self.conv, use_bias=False, dtype=self.dtype)
        # Batch norm uses running averages whenever train is False.
        # axis_name is left unset (the original kept `axis_name='batch'`
        # commented out).
        norm = partial(
            nn.BatchNorm,
            use_running_average=not train,
            momentum=0.9,
            epsilon=1e-5,
            dtype=self.dtype,
        )

        # Stem
        stem = conv(
            self.num_filters,
            (7, 7),
            (2, 2),
            padding=[(3, 3), (3, 3)],
            name='conv_init',
        )
        x = stem(x)
        x = norm(name='bn_init')(x)
        x = nn.relu(x)
        x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')

        # Stages: filters double per stage; the first block of every stage
        # except the first uses stride 2.
        for stage_idx, num_blocks in enumerate(self.stage_sizes):
            for block_idx in range(num_blocks):
                downsample = stage_idx > 0 and block_idx == 0
                block = self.block_cls(
                    self.num_filters * 2**stage_idx,
                    strides=(2, 2) if downsample else (1, 1),
                    conv=conv,
                    norm=norm,
                    act=self.act,
                )
                x = block(x)

        # Head
        x = jnp.mean(x, axis=(1, 2))
        x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
        return jnp.asarray(x, self.dtype)
# Shared bases: the variants below differ only in stage sizes (and, for the
# Local variant, the conv class).
_BasicResNet = partial(ResNet, block_cls=ResNetBlock)
_BottleneckResNet = partial(ResNet, block_cls=BottleneckResNetBlock)

# Variants built on ResNetBlock
ResNet18 = partial(_BasicResNet, stage_sizes=[2, 2, 2, 2])
ResNet34 = partial(_BasicResNet, stage_sizes=[3, 4, 6, 3])

# Variants built on BottleneckResNetBlock
ResNet50 = partial(_BottleneckResNet, stage_sizes=[3, 4, 6, 3])
ResNet101 = partial(_BottleneckResNet, stage_sizes=[3, 4, 23, 3])
ResNet152 = partial(_BottleneckResNet, stage_sizes=[3, 8, 36, 3])
ResNet200 = partial(_BottleneckResNet, stage_sizes=[3, 24, 36, 3])

# ResNet18 layout with nn.ConvLocal substituted for the conv class
ResNet18Local = partial(_BasicResNet, stage_sizes=[2, 2, 2, 2], conv=nn.ConvLocal)
# Config
class config:
    """Hyper-parameters and shared objects for this script (plain namespace)."""
    # Reproducibility
    seed = 42
    seed_rng = jax.random.PRNGKey(seed)

    # Data
    dataset_path = './data'
    image_size = 32
    num_classes = 10
    batch_size = 128
    num_workers = 3
    train_ratio = 0.8  # not referenced anywhere else in the visible script

    # Optimisation
    epochs = 20
    lr = 1e-2
    weight_decay = 1e-2

    # Network
    model = ResNet200(num_classes=num_classes)
def seed_everything():
    """Seed NumPy's and PyTorch's global RNGs with ``config.seed``."""
    seed = config.seed
    torch.manual_seed(seed)
    np.random.seed(seed)


seed_everything()
# Dataset
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays.

    - ndarray samples are stacked along a new leading axis;
    - tuple/list samples are collated field by field (recursively) and
      returned as a list;
    - anything else is converted with ``np.array``.

    Dispatch is decided by the type of the first sample only.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if isinstance(first, (tuple, list)):
        return [numpy_collate(field) for field in zip(*batch)]
    return np.array(batch)
def transform(image, mode):
    """Convert ``image`` to a normalised float32 array; augment it in 'train' mode.

    Normalisation uses mean 0.5 / std 0.5 per channel. When ``mode`` is
    'train', the augmentation pipeline runs on the already-normalised image;
    for any other mode the normalised image is returned as is.
    """
    pixels = np.array(image, dtype=np.float32)
    normalize = albu.Compose([
        albu.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    normalized = normalize(image=pixels)['image']

    if mode != 'train':
        return normalized

    augment = albu.Compose([
        albu.HorizontalFlip(p=0.1),
        albu.RandomResizedCrop(height=32, width=32, scale=(0.8, 1.0), ratio=(0.9, 1.1), p=0.1),
        albu.Blur(blur_limit=10, p=0.1),
        albu.RandomBrightnessContrast(brightness_limit=0.1, contrast_limit=0.1, brightness_by_max=True, p=0.1),
        albu.RandomGridShuffle(p=0.2),
        albu.Cutout(num_holes=4, max_h_size=4, max_w_size=4, fill_value=0, p=0.2),
    ])
    return augment(image=normalized)['image']
# Training data: the CIFAR-10 train split with the 'train' transform.
dataset = torchvision.datasets.CIFAR10(
    root=config.dataset_path,
    train=True,
    download=True,
    transform=partial(transform, mode='train'),
)
# Every sample goes to the first subset (the second has length 0), so no
# hold-out split is carved out of the training data here.
train_dataset, _ = torch.utils.data.random_split(
    dataset, [len(dataset), 0]
)
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=config.batch_size,
    shuffle=True,
    collate_fn=numpy_collate,
    num_workers=config.num_workers,
)

# Validation data: the CIFAR-10 test split with the 'test' transform.
val_dataset = torchvision.datasets.CIFAR10(
    root=config.dataset_path,
    train=False,
    download=True,
    transform=partial(transform, mode='test'),
)
# shuffle=False: evaluation does not need a random order, and a fixed order
# keeps the per-batch averages reported for validation reproducible.
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=config.batch_size,
    shuffle=False,
    collate_fn=numpy_collate,
    num_workers=config.num_workers,
)
# Show Image
def imshow(img):
    """Rescale ``img`` with ``img / 2 + 0.5``, print its array shape and show it.

    ``img`` must support ``.numpy()`` (the caller passes a torch tensor).
    """
    pixels = (img / 2 + 0.5).numpy()
    print(pixels.shape)
    plt.imshow(pixels)
    plt.show()
# Preview the first training sample and print its class name.
images, labels = train_dataset[0]
print(dataset.classes)
preview = torchvision.utils.make_grid(torch.from_numpy(images).clone())
imshow(preview)
print('%5s' % dataset.classes[labels])
# Model Initialization
@jax.jit
def initialize(params_rng):
    """Create the initial variable collections of ``config.model``.

    An all-ones float32 batch of shape (1, image_size, image_size, 3) is fed
    to ``init`` with ``train=False``; ``params_rng`` is passed as the
    'params' RNG stream. Returns whatever ``config.model.init`` returns.
    """
    side = config.image_size
    dummy_batch = jnp.ones((1, side, side, 3), jnp.float32)
    return config.model.init({'params': params_rng}, dummy_batch, train=False)


variables = initialize(config.seed_rng)
# Optimizer
# The learning-rate schedule spans every optimiser step of the whole run.
train_steps_per_epoch = len(train_loader)
num_train_steps = config.epochs * train_steps_per_epoch
schedule_fn = optax.cosine_onecycle_schedule(
    transition_steps=num_train_steps,
    peak_value=config.lr,
)
tx = optax.adamw(
    learning_rate=schedule_fn,
    weight_decay=config.weight_decay,
)
# TrainState
class TrainState(train_state.TrainState):
    """``train_state.TrainState`` extended with a ``batch_stats`` field."""
    # Stored next to params; populated from variables['batch_stats'] below
    batch_stats: Any


state = TrainState.create(
    apply_fn=config.model.apply,
    tx=tx,
    params=variables['params'],
    batch_stats=variables['batch_stats'],
)
# Loss
@jax.jit
def cross_entropy_loss(logits, labels):
    """Return the mean softmax cross-entropy of ``logits`` against ``labels``.

    ``labels`` are one-hot encoded with ``config.num_classes`` classes first.
    """
    targets = common_utils.onehot(labels, num_classes=config.num_classes)
    per_example = optax.softmax_cross_entropy(logits=logits, labels=targets)
    return jnp.mean(per_example)
# Metrics
@jax.jit
def compute_metrics(logits, labels):
    """Return ``{'loss': ..., 'accuracy': ...}`` for one batch.

    Accuracy is the mean of ``argmax(logits, -1) == labels``.
    """
    predictions = jnp.argmax(logits, -1)
    return {
        'loss': cross_entropy_loss(logits, labels),
        'accuracy': jnp.mean(predictions == labels),
    }
def log_metrics(metrics, history, mode):
    """Average per-step metrics, append them to ``history`` and print them.

    For both 'loss' and 'accuracy', the plain (unweighted) mean over the
    entries of ``metrics`` is appended to ``history[f'{mode}_<name>']`` and
    printed. ``history`` is mutated in place and also returned.
    """
    for key in ('loss', 'accuracy'):
        tag = f'{mode}_{key}'
        average = np.mean([step[key] for step in metrics])
        history[tag].append(average)
        print(f'{tag}:', average)
    return history
# Train
@jax.jit
def train_step(state, batch):
    """Run one optimisation step and return ``(new_state, metrics)``.

    ``batch`` is indexed as ``batch[0]`` (images) and ``batch[1]`` (labels).
    """
    images, labels = batch[0], batch[1]

    def loss_fn(params):
        # mutable='batch_stats' makes apply also return the updated collection
        logits, updated_collections = state.apply_fn(
            {'params': params, 'batch_stats': state.batch_stats},
            images,
            train=True,
            mutable='batch_stats',
        )
        return cross_entropy_loss(logits, labels), (updated_collections, logits)

    # has_aux=True: loss_fn returns (loss, aux); the loss value itself is not
    # needed here because compute_metrics recomputes it from the logits.
    (_, (updated_collections, logits)), grads = jax.value_and_grad(
        loss_fn, has_aux=True
    )(state.params)
    metrics = compute_metrics(logits, labels)
    new_state = state.apply_gradients(
        grads=grads,
        batch_stats=updated_collections['batch_stats'],
    )
    return new_state, metrics
@jax.jit
def eval_step(state, batch):
    """Evaluate one batch with ``train=False`` and return its metrics dict.

    ``batch`` is indexed as ``batch[0]`` (images) and ``batch[1]`` (labels).
    """
    images, labels = batch[0], batch[1]
    logits = state.apply_fn(
        {'params': state.params, 'batch_stats': state.batch_stats},
        images,
        train=False,
        mutable=False,
    )
    return compute_metrics(logits, labels)
def train(state, train_loader, val_loader):
    """Train for ``config.epochs`` epochs, validating after each one.

    Returns the ``defaultdict(list)`` history filled in by ``log_metrics``.
    The updated ``state`` stays local to this function and is not returned.
    """
    history = defaultdict(list)
    for epoch in range(1, config.epochs + 1):
        print(f'{epoch}/{config.epochs}')

        print('train')
        train_metrics = []
        for batch in tqdm(train_loader):
            state, step_metrics = train_step(state, batch)
            train_metrics.append(step_metrics)
        history = log_metrics(train_metrics, history, 'train')

        print('validation')
        val_metrics = [eval_step(state, batch) for batch in tqdm(val_loader)]
        history = log_metrics(val_metrics, history, 'val')
    return history
history = train(state, train_loader, val_loader)

# Plot: one figure per metric, with the train and val curves over epochs.
for metric in ('loss', 'accuracy'):
    traces = []
    for split in ('train', 'val'):
        key = f'{split}_{metric}'
        series = history[key]
        traces.append(
            go.Scatter(x=np.arange(1, len(series) + 1), y=series, name=key)
        )
    fig = go.Figure(data=traces)
    fig.update_layout(title=metric)
    fig.show()
参考にしたWebページ
https://qiita.com/mako0715/items/0c9499a7dc124d6d2f40
https://qiita.com/Takayoshi_Makabe/items/79c8a5ba692aa94043f7
https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial9/AE_CIFAR10.html
https://juliusruseckas.github.io/ml/flax-cifar10.html
https://zenn.dev/inoichan/articles/3509d3f2e9211e
https://github.com/google/flax/blob/main/examples/imagenet/models.py
https://www.tc3.co.jp/jaxflax-introduction-with-mnist/
など