1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

SageMakerでカスタムpytorchモデル実装

Last updated at Posted at 2023-05-25

SageMakerでカスタムpytorchモデル実装

記事の目的

SageMakerの分散トレーニング、デプロイオーケストレーションシステムを活用するための

自作pytorchプログラムの改修方法理解

対象者

基礎的な機械学習知識を所有しており、

SageMakerの分散トレーニング、デプロイオーケストレーションシステムを活用したい人

この記事を読み終わるまでの時間

10m

基本的な処理の流れ

事前セットアップ

import sagemaker

sagemaker_session = sagemaker.Session()

bucket = sagemaker_session.default_bucket()
prefix = 'sagemaker/DEMO-pytorch-mnist'

role = sagemaker.get_execution_role()

使用、訓練データ取得

from torchvision import datasets, transforms

datasets.MNIST('data', download=True, transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
]))

s3へのデータアップロード

inputs = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix=prefix)
print('input spec (in this case, just an S3 path): {}'.format(inputs))

sageMakerのエコシステムを用いて訓練、デプロイするため、sagemakerに内包されたPytorchクラスを生成

from sagemaker.pytorch import PyTorch

estimator = PyTorch(entry_point='mnist.py',
                    role=role,
                    framework_version='1.1.0',
                    train_instance_count=2,
                    train_instance_type='ml.c4.xlarge',
                    hyperparameters={
                        'epochs': 6,
                        'backend': 'gloo'
                    })

訓練の実施

estimator.fit({'training': inputs})

本題:SageMakerエコシステムに適合するためのmnist.pyの書き方

ポイント

  • input_fn、predict_fn、model_fn、output_fnの4兄弟を作成しない場合は以下のdefualt_**関数4種類が呼び出される。
  • つまりmodel_fnに関しては、実装がないとエラーとなる。
  • 他のものについては、default_**の処理で問題なければ実装の必要はない。

default関数 default_input_fn

# https://github.com/aws/sagemaker-scikit-learn-container/blob/master/ src/sagemaker_sklearn_container/handler_serving.py より抜粋

        @staticmethod
        def default_input_fn(input_data, content_type):
            """Takes request data and de-serializes the data into an object for prediction.
                When an InvokeEndpoint operation is made against an Endpoint running SageMaker model server,
                the model server receives two pieces of information:
                    - The request Content-Type, for example "application/json"
                    - The request data, which is at most 5 MB (5 * 1024 * 1024 bytes) in size.
                The input_fn is responsible to take the request data and pre-process it before prediction.
            Args:
                input_data (obj): the request data.
                content_type (str): the request Content-Type.
            Returns:
                (obj): data ready for prediction.
            """
            np_array = decoder.decode(input_data, content_type)
            if len(np_array.shape) == 1:
                np_array = np_array.reshape(1, -1)
            return np_array.astype(np.float32) if content_type in content_types.UTF8_TYPES else np_array

上記で呼び出されるdecoder関数。

def decode(obj, content_type):
    """Decode an object that is encoded as one of the default content types.

    Args:
        obj (object): to be decoded.
        content_type (str): content type to be used.

    Returns:
        object: decoded object for prediction.
    """
    try:
        decoder = _decoder_map[content_type]
        return decoder(obj)
    except KeyError:
        raise errors.UnsupportedFormatError(content_type)

補足:_decoder_mapがcontent_typeに応じたconvertメソッドを呼び出し、objを変換したものを返す(csv to numpyや、json to numpyなど)

default関数 default_output_fn

# https://github.com/aws/sagemaker-scikit-learn-container/blob/master/ src/sagemaker_sklearn_container/handler_serving.py より抜粋
def default_output_fn(prediction, accept):
    return encoders.encode(prediction, accept), accept# https://github.com/aws/sagemaker-inference-toolkit/blob/master/ src/sagemaker_inference/encoder.pyより抜粋

def encode(array_like, content_type):
    try:
        encoder = _encoder_map[content_type]
        return encoder(array_like)
    except KeyError:
        raise errors.UnsupportedFormatError(content_type)
_encoder_map = {
content_types.NPY: _array_to_npy, content_types.CSV: _array_to_csv, content_types.JSON: _array_to_json,
		}

def _array_to_json(array_like):
    def default(_array_like):
        if hasattr(_array_like, "tolist"):
					return _array_like.tolist()
		return json.JSONEncoder().default(_array_like)

default関数 default_model_fn (model_fnがない時のexception吐き出しだけ)

def default_model_fn(model_dir):
	"""Loads a model. For Scikit-learn, a default function to load a model is not provided. Users should provide customized model_fn() in script.
	Args:
	           model_dir: a directory where model is saved.Returns: A Scikit-learn model.
	"""
	raise NotImplementedError(textwrap.dedent("""
	Please provide a model_fn implementation.
	See documentation for model_fn at https://github.com/aws/sagemaker-python-sdk """))

default関数 default_predict_fn(実装頻度は低め)

# https://github.com/aws/sagemaker-scikit-learn-container/blob/master/ src/sagemaker_sklearn_container/handler_service.py より抜粋
def default_predict_fn(input_data, model):
"""A default predict_fn for Scikit-learn. Calls a model on data deserialized in input_fn. Args:
          input_data: input data (Numpy array) for prediction deserialized by input_fn
model: Scikit-learn model loaded in memory by model_fn Returns: a prediction
"""
	output = model.predict(input_data)
	return output

boto3からのリクエスト例(invoke_endpointを実施)

response = smr_client.invoke_endpoint( 
	EndpointName=endpoint_name, 
	ContentType='text/csv', 
	Accept='text/csv', 
	Body='1,2,3,10000'
)
predictions = response['Body'].read().decode('utf-8') print(predictions)

上記のポイント

  • ContentTypeをinput_fnで使用
  • Acceptをoutput_fnで使用
  • 実際にmodelに渡されるのはBody
  • 処理の流れは以下
    • (エンドポイント立ち上げ時)model_fn
    • (API実行時)input_fn→predict_fn→output_fn

fnメソッド4種類ののカスタム例(上記API呼び出しを行う想定での改修)

# モデル読み込み
def model_fn(model_dir):
  with open(os.path.join(model_dir,'my_model.txt')) as f:
      hello = f.read()[:-1]
  return hello

# 前処理
def input_fn(input_data, content_type): 
	if content_type == 'text/csv':
        transformed_data = input_data.split(',')
  else:
      raise ValueError("Illegal content type")
  return transformed_data

# 予測
def predict_fn(transformed_data, model): prediction_list = []
	for data in transformed_data:
        if data[-1] == '1':
          ordinal = f'{data}st'
        elif data[-1] == '2':
          ordinal = f'{data}nd'
				elif data[-1] == '3': 
					ordinal = f'{data}rd'
        else:
            ordinal = f'{data}th'
        prediction = f'{model} for the {ordinal} time'
        prediction_list.append(prediction)
    return prediction_list

# 後処理
def output_fn(prediction_list, accept):
    if accept == 'text/csv':
        response = ''
        for prediction in prediction_list:
            response += prediction + '¥n'
    else:
			raise ValueError("Illegal accept type")
			return response, accept

上記処理の際の想定output:

Hello my great machine learning model for the 1st time 
Hello my great machine learning model for the 2nd time 
Hello my great machine learning model for the 3rd time 
Hello my great machine learning model for the 10000th time

完了

参考記事

default_output_fn参考元

メインロジック参考元:

プログラムコード参考元

1
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?