LoginSignup
3
1

More than 1 year has passed since last update.

【SageMaker】Training JobsでFine-Tuningを行う際にmodel.tar.gzをS3から読み込む

Last updated at Posted at 2022-12-19

以前の記事で、SageMaker Training JobsによるTabBERTモデルの事前学習を行ったので、今回は事前学習の結果model.tar.gzを元にFine-Tuningを実行するJobを作成しました。

基本的には事前学習と同じようなJobなのですが、以下の部分で工夫が必要だったのでメモとしてまとめました。

  • tarファイルの展開
  • 環境変数によるローカルとSageMaker間での引数の切替

なお、Fine-Tuningの元コードについては、以下の記事で解説しています。

tarファイルの展開

事前学習のJobによって、S3上にmodel.tar.gzが保存されています。
image.png

model.tar.gz内には、モデルpytorch_model.bin、設定ファイルconfig.json、辞書ファイルvocab.nb、トークン→id変換ファイルvocab_token2id.binが入っています。
Fine-Tuningではこれらを読み込む必要があるため、Jobを実行するときにtarファイルを展開するような工夫を行います。

まずは、Jobファイルのinput_modelmodel.tar.gzのS3パスを指定します。
これでJob実行時にmodel.tar.gz/opt/ml/input/data/input_model/(model_path)以下に置かれます。

fine-tuning_jobs.ipynb
import sagemaker
from sagemaker.estimator import Estimator

session = sagemaker.Session()
role = sagemaker.get_execution_role()

estimator = Estimator(
    image_uri=<イメージURL>,
    role=role,
    instance_type="ml.g4dn.2xlarge",
    instance_count=1,
    base_job_name="tabformer-opt-fine-tuning",
    output_path="s3://<バケット名>/sagemaker/output_data/fine_tuning",
    code_location="s3://<バケット名>/sagemaker/output_data/fine_tuning",
    sagemaker_session=session,
    entry_point="fine-tuning.sh",
    dependencies=["tabformer-opt"],
    hyperparameters={
        "data_root": "/opt/ml/input/data/input_data/",
        "data_fname": "summary",
        "output_dir": "/opt/ml/model/",
        "model_path": "/opt/ml/input/data/input_model/",
    }
)
estimator.fit({
    "input_data": "s3://<バケット名>/sagemaker/input_data/summary.csv",
    "input_model": "s3://<バケット名>/sagemaker/output_data/pre_training/tabformer-opt-2022-12-16-07-00-45-931/output/model.tar.gz"
})

次にFine-Tuningの実行ファイルtabformer_bert_fine_tuning.py上に以下を記載します。

tabformer_bert_fine_tuning.py
        with tarfile.open(name=path.join(args.model_path, f'model.tar.gz'), mode="r:gz") as mytar:
            mytar.extractall(path.join(args.model_path, f'model'))
            
            token2id_file = path.join(args.model_path, f"model/vocab_token2id.bin")
            vocab_file = path.join(args.model_path, f"model/vocab.nb")
            pretrained_model = path.join(args.model_path, f"model/checkpoint-500/pytorch_model.bin")
            pretrained_config = path.join(args.model_path, f"model/checkpoint-500/config.json")

tarfile.open()model.tar.gzが読み込まれ、mytar.extractall(path.join(args.model_path, f'model'))/opt/ml/input/data/input_model/model/以下に中身が展開されます。

これで、token2id_file = path.join(args.model_path, f"model/vocab_token2id.bin")のように展開されたファイルを読み込むことができるようになります。

環境変数によるローカルとSageMaker間での引数の切替

これでS3上のmodel.tar.gzを読み込めるようになったのですが、ローカルでFineTuningを行う際には読み込み先を変えたいケースもあると思います。

そんなときは、os.getenv('SM_MODEL_DIR')でSageMakerの環境変数SM_MODEL_DIR(コンテナ終了時にS3へアップロードされるディレクトリ)を取得し、ローカルとSageMaker(のJob)で読み込み先を切り替えます。

tabformer_bert_fine_tuning.py
    key = os.getenv('SM_MODEL_DIR')
    
    if key :
        with tarfile.open(name=path.join(args.model_path, f'model.tar.gz'), mode="r:gz") as mytar:
            mytar.extractall(path.join(args.model_path, f'model'))
            
            token2id_file = path.join(args.model_path, f"model/vocab_token2id.bin")
            vocab_file = path.join(args.model_path, f"model/vocab.nb")
            pretrained_model = path.join(args.model_path, f"model/checkpoint-500/pytorch_model.bin")
            pretrained_config = path.join(args.model_path, f"model/checkpoint-500/config.json")
    else :
            vocab_file = path.join(args.model_path, f"vocab.nb")
            token2id_file = path.join(args.model_path, f"vocab_token2id.bin")
            pretrained_model = path.join(args.model_path, f"checkpoint-500/pytorch_model.bin")
            pretrained_config = path.join(args.model_path, f"checkpoint-500/config.json")

参考資料

3
1
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
1