きっかけと概要
諸事情により必要となったため、Tensor2Tensor後継のtrax深層学習ライブラリ試用。メモ。
trax:Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team.
https://github.com/google/trax
https://trax-ml.readthedocs.io/en/latest/#
- tensorflowより簡単に
- Tensor2Tensorより簡単に
- Pytorchより簡潔に
- kerasより高機能かつ簡潔に
- 論文のアルゴリズム表記(擬似コード)と同じような書式で実装可能(他ライブラリのように過去の遺産を記述する必要がない)
https://arxiv.org/abs/1905.05621
擬似コード:
http://mirrors.ctan.org/macros/latex/contrib/algorithms/algorithms.pdf
- 速い(そうだが)
小メモリ大入力可transformerであるReformer実装(HuggingFace transformersにも実装あり)
Reformer: The Efficient Transformer
https://ai.googleblog.com/2020/01/reformer-efficient-transformer.htmlドキュメントは少ない
jaxを用いているためwindowsでは動かない(wsl2で動く)
Pytorch-lightning+kerasという印象。debug沼にハマっていないためか、今のところ、kerasと比較したメリットを感じられていない。既存モデルを使うならHuggingFace transformersやkeras、論文実装ならPytorchかtrax、という印象はある。
→慣れると思考の階層が整理されており簡単かなと感じる。Pytorch-lightningのような、継承が原因であろう、非明示的にも記載できるブラックボックス感はない。
参考
Get started with Google Trax for NLP
https://towardsdatascience.com/get-started-with-google-trax-for-nlp-ff8dcd3119cf以下のコードは、元はDLAIのコードです。DLAIの許可のもと変更し開示しています。
環境
Windows10 wsl2 Ubuntu18.04LTS
Python3.7
trax==1.3.1
インストール
pip install trax==1.3.1
実装例
- コピペで動く短いコードとなるように記載。
- 諸事情により一部記載していない。記載していない部分には▲を付している。
試用データ
テキストの2値分類をイメージした次のようなデータを試用に用いた。
*適当な公開データセットから試用データを持ってきたほうが見栄えは良いのだが、入力の書式がわかりにくくなるので今回は避ける。
*修正しながら貼り付けたので間違えている部分があるかもしれない。
train_pos = ['i am happy']*10
train_neg = ['she is sad']*10
train_y = np.append(np.ones(10), np.zeros(10))
前処理は長くなるため省略。テキストが次のように変換されたと思いねぇ。
tensor_pos = []
tensor_neg = []
for i in range(10):
tensor_pos.append([1,2,3])
tensor_neg.append([4,5,6])
vocab = {'i':1, 'am':2, 'happy':3 , 'she':4 , 'is':5 , 'sad':6 }
DataLoader定義
本題ではないので適当に。
*import trax.data as data
data.Serial()を使えば、下記tl.Serial()と同じように前処理を積み重ねるように記述しreturns a generatorとできる。
import numpy as np
def data_generator(tensor_pos, tensor_neg, batch_size):
n_to_take = batch_size // 2
counter = 0
stop = False
while True:
batch = []
for i in range(n_to_take):
if counter>=len(data):
stop = True
break
batch.append(tensor_pos[i])
batch.append(tensor_neg[i])
counter += 1
if stop==True:
break
inputs = np.array(batch)
batch_y = []
for i in range(len(batch)//2):
batch_y += [1]
batch_y += [0]
targets = np.array(batch_y)
example_weights = np.ones_like(targets)
yield inputs, targets, example_weights
New Layer定義
必要に応じ
class Layer(object):
def __init__(self):
self.weights = None
def forward(self, x):
raise NotImplementedError
def init_weights_and_state(self, input_signature, random_key):
pass
def init(self, input_signature, random_key):
self.init_weights_and_state(input_signature, random_key)
return self.weights
def __call__(self, x):
return self.forward(x)
class Dense(Layer): #標準の関数tl.Dense()があるが例として
def __init__(self, n_units, init_stdev=0.1):
self._n_units = n_units
self._init_stdev = init_stdev
def forward(self, x):
▲ #xとself.weight
return dense
def init_weights_and_state(self, input_signature, random_key):
▲ #shapeを見よ
return self.weights
#call return self.forward(x) ほか承継
model定義
入力→embed層→平均層→全結合層→出力
import trax
from trax import layers as tl
vocab_size = len(vocab)
embedding_dim = 256
output_dim = 2
model = tl.Serial(
tl.Embedding(vocab_size,d_feature=embedding_dim), #githubより。以下同じ
tl.Mean(axis=1), #2,3,256 2,256
tl.Dense(output_dim),
tl.LogSoftmax())
*reformerモデルの場合
model = trax.models.reformer.ReformerLM(
vocab_size=vocab_size,
n_layers=n_layers,
mode=mode,
attention_type=tl.SelfAttention)
中身は
Serial[
ShiftRight(1)
Embedding_train_512
Dropout
PositionalEncoding
Dup_out2
ReversibleSerial_in2_out2[
ReversibleHalfResidualV2_in2_out2[
Serial[
LayerNorm
]
SelfAttention
]
ReversibleSwap_in2_out2
ReversibleHalfResidualV2_in2_out2[
Serial[
LayerNorm
Dense_2048
Dropout
FastGelu
Dense_512
Dropout
]
]
ReversibleSwap_in2_out2
ReversibleHalfResidualV2_in2_out2[
Serial[
LayerNorm
]
SelfAttention
]
ReversibleSwap_in2_out2
ReversibleHalfResidualV2_in2_out2[
Serial[
LayerNorm
Dense_2048
Dropout
FastGelu
Dense_512
Dropout
]
]
ReversibleSwap_in2_out2
]
Concatenate_in2
LayerNorm
Dropout
Dense_train
LogSoftmax
]
訓練
from trax.supervised import training
batch_size = 2
train_task = training.TrainTask(
labeled_data = data_generator(tensor_pos, tensor_neg, batch_size),
loss_layer = tl.CrossEntropyLoss(),
optimizer = trax.optimizers.Adam(0.01),
n_steps_per_checkpoint=10,)
eval_task = training.EvalTask(
labeled_data = data_generator(tensor_pos, tensor_neg, batch_size),#*省略の都合上、evalはtrainと同一データとしている
metrics = [tl.CrossEntropyLoss(), tl.Accuracy()],)
training_loop = training.Loop(model, train_task, eval_task=eval_task, output_dir='./')
training_loop.run(n_steps=10)
Step 1: train CrossEntropyLoss | 0.41486400
Step 1: eval CrossEntropyLoss | 0.09423853
Step 1: eval Accuracy | 1.00000000
Step 10: train CrossEntropyLoss | 0.01481570
Step 10: eval CrossEntropyLoss | 0.00004220
Step 10: eval Accuracy | 1.00000000
予測例
tmp_inputs, _t, _w = next(data_generator(tensor_pos, tensor_neg, batch_size))
tmp_pred = training_loop.eval_model(tmp_inputs)
display(tmp_inputs, tmp_pred)
for tmp in tmp_pred:
if tmp[0] < tmp[1]:
print('pos')
else:
print('neg')
array(
[[1, 2, 3],
[4, 5, 6]])
DeviceArray(
[[-9.4024553e+00, -8.2492828e-05],
[-1.9073486e-06, -1.3045208e+01]],
dtype=float32)
pos
neg