以下作業メモ
こんな感じのレポが作れる
https://huggingface.co/totti0223/resnet18_fractaldb_10000
前提
token関係はクリアしていること
アップロード
モデル部分
- 本稿のコードはfractalDB pretrained resnet (MIT ライセンス)リンクをhuggingface経由で利用するための手順
- pytorchモデル(nn.Module)をPyTorchModelHubMixinで継承するラッパーを書く。モデル構造を最初から定義する場合は最初から以下のようにしてもよい
class CNN(nn.Module, PyTorchModelHubMixin):
....
import os
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from torchvision.models import resnet18, resnet34, resnet50, resnet101, resnet152,
class HF_Write_Wrapper(nn.Module, PyTorchModelHubMixin):
def __init__(self, base_model):
super().__init__()
self.model = base_model
def forward(self, x):
return self.model(x)
def _get_model(arch, pretrained=True, dataset='10000', **kwargs):
num_classes = int(dataset)
if arch.startswith('res'):
resnet_factory = {
'res18': resnet18,
'res34': resnet34,
'res50': resnet50,
'res101': resnet101,
'res152': resnet152
}
base_model = resnet_factory[arch](**kwargs)
else:
raise ValueError(f"Unsupported architecture: {arch}")
base_model.fc = nn.Linear(base_model.fc.in_features, num_classes) # fc層を書き換え
if pretrained:
state_dict = _get_state_dict(arch, dataset)
base_model.load_state_dict(state_dict)
base_model.eval()
model = HF_Write_Wrapper(base_model)
return model
def resnet18_fractaldb(pretrained=True, dataset='10000', **kwargs):
"""ResNet-18 model trained on FractalDB dataset"""
return _get_model('res18', pretrained, dataset, **kwargs)
def resnet34_fractaldb(pretrained=True, dataset='10000', **kwargs):
"""ResNet-34 model trained on FractalDB dataset"""
return _get_model('res34', pretrained, dataset, **kwargs)
def resnet50_fractaldb(pretrained=True, dataset='10000', **kwargs):
"""ResNet-50 model trained on FractalDB dataset"""
return _get_model('res50', pretrained, dataset, **kwargs)
def resnet101_fractaldb(pretrained=True, dataset='10000', **kwargs):
"""ResNet-101 model trained on FractalDB dataset"""
return _get_model('res101', pretrained, dataset, **kwargs)
def resnet152_fractaldb(pretrained=True, dataset='10000', **kwargs):
"""ResNet-152 model trained on FractalDB dataset"""
return _get_model('res152', pretrained, dataset, **kwargs)
メタデータカード
huggingfaceではREADME.mdにhuggingfaceがタグや情報を読み取るためのyamlが記述されている必要がある。テンプレートはpython apiからアクセスできる
from huggingface_hub import ModelCard, ModelCardData
card_data = ModelCardData(language='en', license='mit', library='torch')
txt = "arbitrary text"
content = f"""
---
{ card_data.to_yaml() }
---
# My Model Card
{txt}
"""
card = ModelCard(content)
ModelCard.validate(card)
アップロード
modelとcardをそれぞれpush_to_hubでレポジトリ名(自動で新規作成される)を引数にいれればアップロードされる。tokenに権限がないとエラーが出るので注意。
archs = [resnet18_fractaldb,resnet34_fractaldb, resnet50_fractaldb, resnet101_fractaldb, resnet152_fractaldb]
for arch in archs:
print(arch.__name__)
repo_id = f"totti0223/{arch.__name__}_10000"
model = arch()
model.push_to_hub(repo_id)
card.push_to_hub(repo_id)
ダウンロード
同様にPyTorchModelHubMixinを継承したクラスを作成する。from_pretrainedでレポジトリ名を指定。ここではresnetのバックボーンとレポジトリを引数管理したいので以下のようにしているが、self.model = some_model()のように直接呼び出したりしてもよい。
from torchvision import models
from huggingface_hub import PyTorchModelHubMixin
import torch.nn as nn
class HF_Wrapper(nn.Module, PyTorchModelHubMixin):
def __init__(self, model_name="resnet18", num_classes=10000):
super().__init__()
self.model = getattr(models, model_name)(num_classes=num_classes)
def forward(self, x):
return self.model(x)
model_id = "totti0223/resnet18_fractaldb_10000"
model_name = "resnet18"
num_classes = 10000
model = HF_Wrapper(model_name, num_classes=num_classes).from_pretrained(model_id)
model.eval()