以前書いたCloudTPUでtransformer英日翻訳モデルを学習&推論する手順では、CloudTPUでtransformer英日翻訳モデルを学習し、推論も行いました。
今回はCloudTPUで学習済みのtransformerをローカルのDockerコンテナで実行する方法を説明します。
コードはこちらです。
https://github.com/yolo-kiyoshi/transformer_python_exec
前提
GCS、ローカルで以下のディレクトリ構造でファイルが配置されているとします。
GCSディレクトリ
bucket
├── training/
│ └── transformer_ende/
│ ├── checkpoint
│ ├── model.ckpt-****.data-00000-of-00001
│ ├── model.ckpt-****.index
│ └── model.ckpt-****.meta
└── transformer/
└── vocab.translate_jpen.****.subwords
ローカルディレクトリ
リポジトリをcloneします。
git clone https://github.com/yolo-kiyoshi/transformer_python_exec.git
.
├── Dockerfile
├── .env.sample
├── Pipfile
├── Pipfile.lock
├── README.md
├── decode.ipynb
├── docker-compose.yml
├── training/
│ └── transformer_ende/
└── transformer/
準備
Google Credentialファイル
サービスアカウントのCredentialファイル(json)をダウンロードし、README.mdと同一ディレクトリに配置します。
環境変数
.env.sample
を複製&リネームし、.env
を作成します。
# 上で配置したCredentialファイルのパスを記載
GOOGLE_APPLICATION_CREDENTIALS=*****.json
BUDGET_NAME=
# CloudTPUで学習したときと同一の設定
PROBLEM=translate_jpen
DATA_DIR=transformer
TRAIN_DIR=training/transformer_ende/
HPARAMS=transformer_tpu
MODEL=transformer
Dockerイメージの作成とコンテナの起動
以下のコマンドを実行後、http://localhost:8080/lab にアクセスするとJupyter labを操作できます
docker-compose up -d
Notebook
transformer学習結果をGCSからダウンロード
GCSからtransformer学習過程で作成されるcheckpoint
ファイル一式とvocab
ファイルをローカルにダウンロードする。
# GCSからファイルをダウンロードするメソッド(https://cloud.google.com/storage/docs/downloading-objects?hl=ja)
def download_blob(bucket_name, source_blob_name, destination_file_name):
"""Downloads a blob from the bucket."""
storage_client = storage.Client()
bucket = storage_client.get_bucket(bucket_name)
blob = bucket.blob(source_blob_name)
blob.download_to_filename(destination_file_name)
print('Blob {} downloaded to {}.'.format(
source_blob_name,
destination_file_name))
# GCSのファイル一覧取得メソッドを参考
# https://cloud.google.com/storage/docs/listing-objects?hl=ja#storage-list-objects-python
def list_match_file_with_prefix(bucket_name, prefix, search_path):
"""Lists all the blobs in the bucket that begin with the prefix."""
storage_client = storage.Client()
# Note: Client.list_blobs requires at least package version 1.17.0.
blobs = storage_client.list_blobs(bucket_name, prefix=prefix, delimiter=None)
file_list = [blob.name for blob in blobs if search_path in blob.name]
return file_list
# 環境変数を設定
BUDGET_NAME = os.environ['BUDGET_NAME']
PROBLEM = os.environ['PROBLEM']
DATA_DIR = os.environ['DATA_DIR']
TRAIN_DIR = os.environ['TRAIN_DIR']
HPARAMS = os.environ['HPARAMS']
MODEL = os.environ['MODEL']
# checkpointファイルのパス
src_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
dist_file_name = os.path.join(TRAIN_DIR, 'checkpoint')
# checkpointファイルをGCSからダウンロード
download_blob(BUDGET_NAME, src_file_name, dist_file_name)
# checkpointファイルから最新のcheckpointシーケンス(prefix)を取得する
import re
with open(dist_file_name) as f:
l = f.readlines(1)
ckpt_name = re.findall('model_checkpoint_path: "(.*?)"', l[0])[0]
ckpt_path = os.path.join(TRAIN_DIR, ckpt_name)
# 最新のcheckpointに紐づくファイルリストをGCSから取得する
ckpt_file_list = list_match_file_with_prefix(BUDGET_NAME, TRAIN_DIR, ckpt_path)
# checkpoint.variableを一式ダウンロード
for ckpt_file in ckpt_file_list:
download_blob(BUDGET_NAME, ckpt_file, ckpt_file)
# vocabファイルパスをGCSから取得する
vocab_file = list_match_file_with_prefix(BUDGET_NAME, DATA_DIR, os.path.join(DATA_DIR, 'vocab'))[0]
# vocabファイルをGCSからダウンロード
download_blob(BUDGET_NAME, vocab_file, vocab_file)
学習済みtransformerモデルをロード
GCSからダウンロードしたtransformer学習結果をもとに、transformerモデルをロードします。
# 初期化
tfe = tf.contrib.eager
tfe.enable_eager_execution()
Modes = tf.estimator.ModeKeys
import pickle
import numpy as np
from tensor2tensor.data_generators import problem
from tensor2tensor.data_generators import text_problems
from tensor2tensor.utils import registry
# 前処理&学習で定義したPROBLEと同一のClass名にすること
@registry.register_problem
class Translate_JPEN(text_problems.Text2TextProblem):
@property
def approx_vocab_size(self):
return 2**13
enfr_problem = problems.problem(PROBLEM)
# Get the encoders from the problem
encoders = enfr_problem.feature_encoders(DATA_DIR)
from functools import wraps
import time
def stop_watch(func) :
@wraps(func)
def wrapper(*args, **kargs) :
start = time.time()
print(f'{func.__name__} started ...')
result = func(*args,**kargs)
elapsed_time = time.time() - start
print(f'elapsed_time:{elapsed_time}')
print(f'{func.__name__} completed')
return result
return wrapper
@stop_watch
def translate(inputs):
encoded_inputs = encode(inputs)
with tfe.restore_variables_on_create(ckpt_path):
model_output = translate_model.infer(features=encoded_inputs)["outputs"]
return decode(model_output)
def encode(input_str, output_str=None):
"""Input str to features dict, ready for inference"""
inputs = encoders["inputs"].encode(input_str) + [1]
batch_inputs = tf.reshape(inputs, [1, -1, 1])
return {"inputs": batch_inputs}
def decode(integers):
"""List of ints to str"""
integers = list(np.squeeze(integers))
if 1 in integers:
integers = integers[:integers.index(1)]
return encoders["inputs"].decode(np.squeeze(integers))
hparams = trainer_lib.create_hparams(HPARAMS, data_dir=DATA_DIR, problem_name=PROBLEM)
translate_model = registry.model(MODEL)(hparams, Modes.PREDICT)
推論
ロードしたtransformerモデルで推論します。
ローカルで実行すると、1つsentenceで30秒程度時間がかかります。
inputs = "My cat is so cute."
outputs = translate(inputs)
print(outputs)
> 私の猫はとてもかわいいです。