0
0

More than 3 years have passed since last update.

Pytorch-LightningのCheckpointでLAN内のminio(S3互換ストレージ)へモデルを保存する方法

Last updated at Posted at 2021-03-14

概要

現在のpytorch-lightning(ver.1.2.3)では、S3への保存を行う際にendpointを指定できないため、LAN内にあるS3互換ストレージなどへ保存を行うことが出来ない。
以下のsetup_endpointを実行することで、fsspecモジュールへ強引にモンキーパッチを当ててエンドポイントを指定する。

方法

  • pytorch-lightning がクラウドストレージへのIOで使用している関数へモンキーパッチを当てEndpointを強引に指定する。
  • 以下のコードをコピペしてsetup_endpointを実行するだけで可能。

コード

from typing import Any
import gorilla
import fsspec

def apply_gorrila(function: Any, module: Any):
    patch = gorilla.Patch(
        module, 
        function.__name__, 
        function, 
        settings=gorilla.Settings(allow_hit=True))
    gorilla.apply(patch)

def setup_endpoint(endpoint_url: str):
    def filesystem(protocol, **storage_options):
        if protocol == "s3":
            storage_options["client_kwargs"] = {
                "endpoint_url": endpoint_url,
            }
        original = gorilla.get_original_attribute(fsspec, 'filesystem')
        return original(protocol, **storage_options)
    def open_files(*args, **kwargs):
        kwargs["client_kwargs"] = {
                "endpoint_url": endpoint_url,
        }
        original = gorilla.get_original_attribute(fsspec.core, 'open_files')
        return original(*args, **kwargs)
    apply_gorrila(filesystem, fsspec)
    apply_gorrila(open_files, fsspec.core)
0
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
0