4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

PyTorch初挑戦の記録、tf.kerasとの比較実装

Last updated at Posted at 2020-12-16

連載中?の会議見える化シリーズこの先に進むにはPyTorchやtorchaudioを使っていく必要がありそうです。そのため、PyTorchに入門してみました。その記録です。

2020/12/20: PyTorchのコードをDataset、DataLoaderを使うものに修正しました。

こちらの記事では、PyTorchとtf.kerasとの比較も述べられています。それぞれお互いの長所を取り込みつつ進化しており、対応力を広げていくにはPyTorchとtf.kerasのバイリンガルになっておく必要がありそうです。Udemyには、Pythonや機械学習の基礎は知ってますよ、という人向けのPyTorch入門コースがあり、参考になります。

実装テーマとしては以下としています。

  • 4次関数(+乱数ノイズ)を多層パーセプトロン(MLP)で近似
  • train〜testの2分割モデル、validationは作らない
  • testの精度は計算せずに、testデータと推論結果をグラフ表示する

4次関数を題材にしたのは、単純でありながら、多層のありがたみがそれなりに出るからです。

tf.kerasは伝統的?なSequential型の実装をしていますが、PyTorch風の実装も可能とのことです。今回の実装方法で比較すると、tf.kerasは記述の抽象度が高く、PyTorchはいろいろ記述が必要ですがpython風の記述であるためカスタマイズがしやすいように感じます。

一方、PyTorchはtorch.Tensor型への依存度が高く、scipyとかpandasとかsklearnとかのnumpyエコシステムを直接活用できないもどかしさがあります。早い段階でtorch.Tensor型を前提としたPyTorchの世界に移る必要があります。具体的には、torch.utils.data.Dataset型にデータをまとめ、前処理はtransformで行い、torch.utils.data.DataLoader型を介して学習やテストを行うのがPyTorch的です。

以下がtf.kerasでの実装です。10か月前の機械学習入門時の記事と比較して、個人的にこなれた感があるなーと思いました。

import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Dense

# データ作成
x = np.linspace(-1.5, 1.5, 1000)
noise = np.random.randn(1000) * 0.1
y = x ** 4 - 2 * x ** 2 + noise
x_train, x_test, y_train, y_test = train_test_split(x, y, train_size=0.8)

# モデル作成
model = Sequential([
    Dense(128, input_dim=1, activation='relu'),
    Dense(64, activation='relu'),
    Dense(32, activation='relu'),
    Dense(16, activation='relu'),
    Dense(1)
])
model.compile(optimizer='adam', loss='mse', metrics='mse')

# 学習
result = model.fit(x_train, y_train, batch_size=128, epochs=100)

# 推論
y_pred = model.predict(x_test)

# 結果表示
plt.scatter(x_test, y_test)
plt.scatter(x_test, y_pred, color='red')
plt.show()

以下がPyTorchでの実装です。

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, TensorDataset, random_split

# データ作成
x = np.linspace(-1.5, 1.5, 1000)
noise = np.random.randn(1000) * 0.1
y = x ** 4 - 2 * x ** 2 + noise
dataset = TensorDataset(torch.tensor(x).float().view(-1, 1),
                        torch.tensor(y).float().view(-1, 1))
n_train = int(len(dataset) * 0.8)
n_test = len(dataset) - n_train
train_dataset, test_dataset = random_split(dataset, [n_train, n_test])
train_dataloader = DataLoader(train_dataset, batch_size=128 , shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=n_test, shuffle=True)

# モデル作成
class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.classifier = nn.Sequential(
            nn.Linear(1, 128), nn.ReLU(),
            nn.Linear(128, 64), nn.ReLU(),
            nn.Linear(64, 32), nn.ReLU(),
            nn.Linear(32, 16), nn.ReLU(),
            nn.Linear(16, 1)
        )
    def forward(self, x):
        x = self.classifier(x)
        return x

model = MLP()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters())

# 学習
for epoch in range(100):
    running_loss = 0.0
    for x_train, y_train in train_dataloader:
        optimizer.zero_grad()
        y_pred = model(x_train)
        loss = criterion(y_pred, y_train)
        loss.backward()
        running_loss += loss.item()
        optimizer.step()
    running_loss /= len(train_dataloader)
    print(f'epoch: {epoch}, loss: {running_loss}')

# 推論
x_test, y_test = iter(test_dataloader).next()
y_pred = model(x_test)

# 結果表示
plt.scatter(x_test, y_test)
plt.scatter(x_test, y_pred.detach(), color='red')
plt.show()

以下が出力結果のグラフです。当然ながら、tf.keras、PyTorchともにほぼ同様の結果が出ます。
スクリーンショット 2020-12-16 20.23.19.png

4
7
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
4
7

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?