オンラインツール公開
通常のGPT-2は英語しか対応しないが、多言語対応したオンラインツールを開発した。
設計
実際には3段構成になっており、
- 言語判別と英語への翻訳
- GPT-2本体処理
- 判別されたオリジナル言語への英語からの復元
という、ニセの多言語対応である。
モデルは最近公開された最高性能の1558Mに対応したので、英語では非常に自然な文章が生成されているようだ。
対して、日本語ではいかにも機械翻訳な文章になる。
実装
WebのマイクロサービスAPIとして提供できる形でのコード全体はGithubであとで公開する予定。
Tensorflowレイヤ
Gitで公開されているコードをWebのリクエストを受けるタイミングで呼び出せように、改変した。
TensorflowでGPT-2を実装する部分。
class GPT2Engine():
def __init__(self,
model_name='1558M',
seed=None,
nsamples=1,
batch_size=1,
length=None,
temperature=1,
top_k=0,
top_p=1,
models_dir='models',
return_as_generator=False,
gpu_list='0,1'
):
self.batch_size = batch_size
self.nsamples = nsamples
models_dir = os.path.expanduser(os.path.expandvars(models_dir))
if self.batch_size is None:
self.batch_size = 1
assert self.nsamples % self.batch_size == 0
self.enc = encoder.get_encoder(model_name, models_dir)
hparams = model.default_hparams()
with open(os.path.join(models_dir, model_name, 'hparams.json')) as f:
hparams.override_from_dict(json.load(f))
if length is None:
length = hparams.n_ctx // 2
elif length > hparams.n_ctx:
raise ValueError("Can't get samples longer than window size: %s" % hparams.n_ctx)
from tensorflow.python.client import device_lib
def get_available_gpus():
local_device_protos = device_lib.list_local_devices()
return [x.name for x in local_device_protos if x.device_type == 'GPU']
available_gpus = get_available_gpus()
print(available_gpus)
if len(available_gpus)>0:
self.config = tf.ConfigProto(
gpu_options=tf.GPUOptions(
visible_device_list=gpu_list, # specify GPU number
#per_process_gpu_memory_fraction=0.8,
allow_growth=True
)
)
else:
self.config = tf.ConfigProto()
self.seed = seed
self.hparams = hparams
self.length = length
self.temperature = temperature
self.top_k = top_k
self.top_p = top_p
self.models_dir = models_dir
self.model_name = model_name
graph = tf.get_default_graph()
sess = tf.Session(graph=graph,config=self.config)
self.context = tf.placeholder(tf.int32, [self.batch_size, None])
np.random.seed(self.seed)
tf.set_random_seed(self.seed)
if True:
self.output = sample.sample_sequence(
hparams=self.hparams, length=self.length,
context=self.context,
batch_size=self.batch_size,
temperature=self.temperature, top_k=self.top_k, top_p=self.top_p
)
else:
self.output = sample.sample_sequence(
hparams=self.hparams, length=self.length,
start_token=self.enc.encoder['<|endoftext|>'],
batch_size=self.batch_size,
temperature=self.temperature, top_k=self.top_k, top_p=self.top_p
)[:, 1:]
self.saver = tf.compat.v1.train.Saver()
ckpt = tf.train.latest_checkpoint(os.path.join(self.models_dir, self.model_name))
self.restored = self.saver.restore(sess, ckpt)
print(self.restored)
self.sess = sess
def run(self,raw_text,
length=None,
temperature=1,
top_k=0,
top_p=1
):
print("length={}".format(length))
context_tokens = self.enc.encode(raw_text)
generated = 0
for _ in range(self.nsamples // self.batch_size):
if raw_text is None:
out = self.sess.run(self.output)
else:
out = self.sess.run(self.output, feed_dict={
self.context: [context_tokens for _ in range(self.batch_size)]
})[:, len(context_tokens):]
for i in range(self.batch_size):
generated += 1
text = self.enc.decode(out[i])
print(text)
yield text
翻訳用
今回はAWS Translate を利用している。
def get_translate_text(text,source_lang="auto",target_lang="en"):
AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID', '')
AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY', '')
REGION_NAME = os.getenv('REGION_NAME', '')
translate = boto3.client('translate', aws_access_key_id=AWS_ACCESS_KEY_ID, aws_secret_access_key=AWS_SECRET_ACCESS_KEY, region_name=REGION_NAME)
response = translate.translate_text(
Text=text,
SourceLanguageCode=source_lang,
TargetLanguageCode=target_lang
)
return response['TranslatedText'], response.get('SourceLanguageCode')
文章補正レイヤ
WebエンドポイントとTensorflowクラスの間で、翻訳や文章補正、文字数制限などを行う。
class GPT2WebWrapper():
def __init__(self):
def _int(num):
if num is None:
return None
else:
return int(num)
self.PASSAGE = os.getenv('PASSAGE', '')
self.MODE = os.getenv('MODE', 'PREDICT_UNCONDITIONAL')
self.MODEL_NAME = os.getenv('MODEL_NAME', '1558M')
self.SEED = _int(os.getenv('SEED', None))
self.NSAMPLES = _int(os.getenv('NSAMPLES', 1))
self.BATCH_SIZE = _int(os.getenv('BATCH_SIZE', 1))
self.LENGTH = _int(os.getenv('LENGTH', 30))
self.USER_LANGUAGE = _int(os.getenv('USER_LANGUAGE', None))
self.TEMPERATURE = float(os.getenv('TEMPERATURE', 1.0))
self.TOP_K = _int(os.getenv('TOP_K', 0))
self.TOP_P = _int(os.getenv('TOP_P', 1))
self.MODELS_DIR = os.getenv('MODELS_DIR', 'models')
self.GPU_MEMORY_FRACTION = float(os.getenv('GPU_MEMORY_FRACTION', 1.0))
self.GPU_LIST = os.getenv('GPU_LIST', None)
self.RUN_NAME = os.getenv('RUN_NAME', 'run1')
self.LENGTH_LIMIT = _int(os.getenv('LENGTH_LIMIT', 400))
self.engine = GPT2Engine(
model_name=self.MODEL_NAME,
seed=self.SEED,
nsamples=self.NSAMPLES,
batch_size=self.BATCH_SIZE,
length=self.LENGTH,
temperature=self.TEMPERATURE,
top_k=self.TOP_K,
top_p=self.TOP_P,
models_dir=self.MODELS_DIR,
return_as_generator=True,
gpu_list=self.GPU_LIST
)
def run(self,_in,
length=None,
temperature=None,
top_k=None,
top_p=None,
target_lang="en"
):
empty = False
if _in is None:
_in = 'Sure.'
empty = True
_in_f = ""
_in_l = _in
if len(_in) > self.LENGTH_LIMIT:
_in_f = _in[:-self.LENGTH_LIMIT]
_in_l = _in[-self.LENGTH_LIMIT:]
translated_in, detected_lang = get_translate_text(_in_l, source_lang='auto', target_lang='en')
if not empty:
target_lang = detected_lang
length = length if length is not None else self.LENGTH
temperature = temperature if temperature is not None else self.TEMPERATURE
top_k = top_k if top_k is not None else self.TOP_K
top_p = top_p if top_p is not None else self.TOP_P
_out = ""
iter = self.engine.run(translated_in,
length=length,
temperature=temperature,
top_k=top_k,
top_p=top_p
)
while iter:
try:
_out = str(iter.__next__())
except:
#traceback.print_exc()
break
_out = translated_in+" "+_out
_out = _out.replace("<|endoftext|>"," ")
translated_out, input_lang = get_translate_text(_out, source_lang='en', target_lang=target_lang)
translated_out = translated_out.replace("\n"," ")
translated_out = _in_f + translated_out
print(translated_out)
return _out, translated_in, translated_out, detected_lang
課題
モデルとして日本語ネイティブ対応のほうが、より自然になるだろう。
GPT-2はもともと英語圏のニュースサイトなどを学習データとして利用しているため、そこで語られるような文章を生成するのは得意だ。GPT-2は基本的には、最初に文章を与え、その後に来るのにふさわしい文章を推測する。例えば日本語で「トランプ大統領は」と入れて生成すると、少しはまともな文章になる。
一方で、英語圏の文化から離れた単語を与えた場合、不自然になりやすい。自然な日本語を得るためには、翻訳を使用しないことはもちろんだが、日本語圏の文章を学習させることも非常に重要だと感じる。
ちなみに生成された文章からさらに再帰的に文章を生成可能だが、英語以外だと、翻訳が繰り返し行われることで、どんどん劣化していくことも問題だ。
しかし、日本語ネイティブモデルのニーズってあるのだろうか。
今後の文章AI生成の展望
現時点の大規模教師なし学習での文章生成は自然さという意味では、高くなってきたが、何かに適応できるかというとイマイチであると感じる。
私の個人的意見であるが、以下のよう課題感がある。
- 常識や事実をいかに学習させられるか
- 目的を持った文章を生成できるか
このような部分が解決されれば、ビジネス的にも価値がでてくると思う。
おわりに
GPT-2関連ツールは絶賛、開発中なので、興味があれば、こちらの記事も読んでほしい。