2
2

More than 1 year has passed since last update.

ディープラーニングでよく使うコード

Last updated at Posted at 2023-05-03

目的

個人的にディープラーニングの学習時によく使うコードを備忘録として記載

Pytorch

モデルの保存

import torch
# "model"はモデルのインスタンス名、その後の引数で保存用パスを記載する
torch.save(model.state_dict(), 'xxx/model.pt')

保存したモデルの読み込み

import torch
from torch.utils.data import DataLoader
import torchvision.transforms as T

# モデルのインスタンス化(保存時と同じモデル、引数の指定が必要)
model = Model()
# モデルの読み込み
model.load_state_dict(torch.load('xxx/model.pt'))
# 評価モードにする
model.eval()

# 読み込んだモデルで予測する為のデータを用意
## 以下はCIFAR10での例だが、実際は学習時と同形式のデータで用意する
transforms = T.Compose((
    T.ToTensor(),
    T.Normalize(mean=channel_mean, std=channel_std),
    ))

dataset = torchvision.datasets.CIFAR10(
        root='data', train=False, download=True,
        transform=transforms)

loader = DataLoader(dataset, batch_size=32)

'''
1バッチサイズ分のみ予測
'''
# バッチサイズ分のデータを取り出し
tmp = iter(loader)
x, y = next(tmp) # x:予測用データ、y:正解ラベル

# 予測値の計算
## (1) PytorchのLinear関数使った場合(基本(2)を使えば良い)
def get_device(self):
    return self.linear.weight.device # self.linearはクラスで指定したLinear関数

x = x.to(model.get_device()) # モデルのデバイスを指定
pred = model(x)

## (2) 単純にデバイス指定する場合
x = x.to("cuda") # GPUの場合
pred = model(x)

Google Colab

Googleドライブをマウント

from google.colab import drive
drive.mount('/content/drive')

デバイス設定の確認

Google Colabで割り当てられているGPUの種類を見るときに使う

# TensorFlow経由でデバイス設定の確認
from tensorflow.python.client import device_lib
device_lib.list_local_devices()

モジュールのリロード

モジュール内のファイルを修正して、再起動せずに反映させる方法
(これやっても反映されないことはちょくちょくある。そういうときは諦めてランタイムの再起動をする)

# devディレクトリ直下にあるutilsというモジュールを対象とした場合
base_dir = './drive/MyDrive/dev/'

import sys
sys.path.append(base_dir)

# インポート
import utils

# リロード
import importlib
importlib.reload(utils)
2
2
1

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
2