0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

pytorchのcustom modelをhuggingfaceにupload/download

Last updated at Posted at 2024-09-05

以下作業メモ
こんな感じのレポが作れる
https://huggingface.co/totti0223/resnet18_fractaldb_10000

前提

token関係はクリアしていること

アップロード

モデル部分

  • 本稿のコードはfractalDB pretrained resnet (MIT ライセンス)リンクをhuggingface経由で利用するための手順
  • pytorchモデル(nn.Module)をPyTorchModelHubMixinで継承するラッパーを書く。モデル構造を最初から定義する場合は最初から以下のようにしてもよい
class CNN(nn.Module, PyTorchModelHubMixin):
....
import os
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152, 

class HF_Write_Wrapper(nn.Module, PyTorchModelHubMixin):
    def __init__(self, base_model):
        super().__init__()
        self.model = base_model
    def forward(self, x):
        return self.model(x)

def _get_model(arch, pretrained=True, dataset='10000', **kwargs):
    num_classes = int(dataset)
    if arch.startswith('res'):
        resnet_factory = {
            'res18': resnet18,
            'res34': resnet34,
            'res50': resnet50,
            'res101': resnet101,
            'res152': resnet152
        }
        base_model = resnet_factory[arch](**kwargs)
    else:
        raise ValueError(f"Unsupported architecture: {arch}")

    base_model.fc = nn.Linear(base_model.fc.in_features, num_classes)  # fc層を書き換え

    if pretrained:
        state_dict = _get_state_dict(arch, dataset)
        base_model.load_state_dict(state_dict)

    base_model.eval()
    model = HF_Write_Wrapper(base_model)

    return model

def resnet18_fractaldb(pretrained=True, dataset='10000', **kwargs):
    """ResNet-18 model trained on FractalDB dataset"""
    return _get_model('res18', pretrained, dataset, **kwargs)

def resnet34_fractaldb(pretrained=True, dataset='10000', **kwargs):
    """ResNet-34 model trained on FractalDB dataset"""
    return _get_model('res34', pretrained, dataset, **kwargs)

def resnet50_fractaldb(pretrained=True, dataset='10000', **kwargs):
    """ResNet-50 model trained on FractalDB dataset"""
    return _get_model('res50', pretrained, dataset, **kwargs)

def resnet101_fractaldb(pretrained=True, dataset='10000', **kwargs):
    """ResNet-101 model trained on FractalDB dataset"""
    return _get_model('res101', pretrained, dataset, **kwargs)

def resnet152_fractaldb(pretrained=True, dataset='10000', **kwargs):
    """ResNet-152 model trained on FractalDB dataset"""
    return _get_model('res152', pretrained, dataset, **kwargs)

メタデータカード

huggingfaceではREADME.mdにhuggingfaceがタグや情報を読み取るためのyamlが記述されている必要がある。テンプレートはpython apiからアクセスできる

from huggingface_hub import ModelCard, ModelCardData

card_data = ModelCardData(language='en', license='mit', library='torch')

txt = "arbitrary text"

content = f"""
---
{ card_data.to_yaml() }
---

# My Model Card
{txt}
"""

card = ModelCard(content)
ModelCard.validate(card)

アップロード

modelとcardをそれぞれpush_to_hubでレポジトリ名(自動で新規作成される)を引数にいれればアップロードされる。tokenに権限がないとエラーが出るので注意。

archs = [resnet18_fractaldb,resnet34_fractaldb, resnet50_fractaldb, resnet101_fractaldb, resnet152_fractaldb]
for arch in archs:
    print(arch.__name__)
    repo_id = f"totti0223/{arch.__name__}_10000"
    model = arch()
    model.push_to_hub(repo_id)
    card.push_to_hub(repo_id)

ダウンロード

同様にPyTorchModelHubMixinを継承したクラスを作成する。from_pretrainedでレポジトリ名を指定。ここではresnetのバックボーンとレポジトリを引数管理したいので以下のようにしているが、self.model = some_model()のように直接呼び出したりしてもよい。

from torchvision import models
from huggingface_hub import PyTorchModelHubMixin
import torch.nn as nn

class HF_Wrapper(nn.Module, PyTorchModelHubMixin):
    def __init__(self, model_name="resnet18", num_classes=10000):
        super().__init__()
        self.model = getattr(models, model_name)(num_classes=num_classes)
    def forward(self, x):
        return self.model(x)

model_id = "totti0223/resnet18_fractaldb_10000"
model_name = "resnet18"
num_classes = 10000
model = HF_Wrapper(model_name, num_classes=num_classes).from_pretrained(model_id)
model.eval()
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?