1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Pytorch備忘録

1
Last updated at Posted at 2026-01-18
💡

RTX5000番台(Blackwell architechture)はTensorflow2.xではネイティブで対応していない(Googleに見捨てられたフレームワーク( ´∀` )).そのため,Pytorchに移行しています.その際に,学んだことなどを書きためておきます.

前提

torchでは,CPU上のデータとGPU上のデータで計算をすることができません.そのため,使用するデバイスを明示するべきです.これはGPUとCPUの衝突を防ぐためです.

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu)

データセットの読み込み

import pandas as pd
import torch
import numpy

#csvの読み込み
df = pd.read_csv('hoge.csv', encoding = 'UTF-8')
#dataframeをndarrayに変換
x = x.to_numpy().astype("float32")
#ndarrayをtensorに変換
x = torch.from_numpy(x).to(device)
#学習用と検証用にsplit
train_data, val_data = torch.utils.data.random_split(x, [0.8 ,0.2])

基本のひな型

class Net(nn.Module);
	def __init__(self);
		super().__init__()
		
	def forward(self,x);
		
		return x

モデルの保存

モデル構造全体を保存する場合

#保存
scripted_model = torch.jit.script(model)
scripted_model.save('sample_model.pth')

#読み込み
load_model = torch.jit.load('sample_model.pth')

重みを保存する場合

#保存
torch.save(model.state_dict(), 'model_weight.pth')

#読み込み
weight_from_model = Net()
weight_from_model.load_state_dict(torch.load('model_weight.pth')

モデルの可視化

Pytorchで作成したモデルの可視化のためのライブラリとして,Torchvistaというのがあります.非常に見やすく可視化してくれるのでおすすめ.

Reference

ひな型

import torch 
import torch.nn as nn
from torchvista import trace_model

#今回はサンプルのために簡単なCNNを作ってみます

class TestModel(nn.Module):
	def __init__(self, in_ch, out_ch):
		super().__init__()
		self.seq = nn.Sequential(
		nn.BatchNorm2d(in_ch),
		nn.Conv2d(in_ch, out_ch, kernel_size = 3, padding = 1),
		nn.ReLU(),
		)
	def foward(self,X):
		out = self.seq(X)
		return out
	
model = Testmodel(192, 128)
x = torch.rand(1, 192, 28, 28)
trace _model(model, x)

学習の効率化

学習をしているとどうしてもOut Of Memoryになってしまうことがあります.これを解決するためにCheckpointingを使います.通常の学習ではForwardでの計算結果をRAMにすべて保存しますが,Checkpointingでは,一部のチェックポイントのみを残し,ほかの部分を捨てます.これによりRAMを解放できます.Backward中に勾配が必要になったら,その部分だけForwardで再計算を行って,作り直します.乱数が変わってしまうと計算結果が変わるのでは?と思われるかもしれませんが,Pytrochは内部でこの乱数シードの保存・復元を自動的に処理してくれるようです.少し計算時間は伸びますが,VRAM不足を劇的に解消できます.

基本のひな型

from torch.utils.checkpoint import checkpoint

output = checkpoint(layer, input_tensor, use_reentrant = False)
1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?