LoginSignup
3
4

More than 3 years have passed since last update.

trax: focuses on clear code and speed 深層学習ライブラリの試用

Last updated at Posted at 2021-01-23

きっかけと概要

諸事情により必要となったため、Tensor2Tensor後継のtrax深層学習ライブラリ試用。メモ。

trax:Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team.
https://github.com/google/trax
https://trax-ml.readthedocs.io/en/latest/#

  • tensorflowより簡単に
  • Tensor2Tensorより簡単に
  • Pytorchより簡潔に
  • kerasより高機能かつ簡潔に
  • 論文のアルゴリズム表記(擬似コード)と同じような書式で実装可能(他ライブラリのように過去の遺産を記述する必要がない)

image.png
https://arxiv.org/abs/1905.05621
擬似コード:
http://mirrors.ctan.org/macros/latex/contrib/algorithms/algorithms.pdf

  • 速い(そうだが)
  • 小メモリ大入力可transformerであるReformer実装(HuggingFace transformersにも実装あり)
    Reformer: The Efficient Transformer
    https://ai.googleblog.com/2020/01/reformer-efficient-transformer.html

  • ドキュメントは少ない

  • jaxを用いているためwindowsでは動かない(wsl2で動く)

  • Pytorch-lightning+kerasという印象。debug沼にハマっていないためか、今のところ、kerasと比較したメリットを感じられていない。既存モデルを使うならHuggingFace transformersやkeras、論文実装ならPytorchかtrax、という印象はある。
    →慣れると思考の階層が整理されており簡単かなと感じる。Pytorch-lightningのような、継承が原因であろう、非明示的にも記載できるブラックボックス感はない。

参考

環境

Windows10 wsl2 Ubuntu18.04LTS
Python3.7
trax==1.3.1

インストール

pip install trax==1.3.1

実装例

  • コピペで動く短いコードとなるように記載。
  • 諸事情により一部記載していない。記載していない部分には▲を付している。

試用データ

テキストの2値分類をイメージした次のようなデータを試用に用いた。
*適当な公開データセットから試用データを持ってきたほうが見栄えは良いのだが、入力の書式がわかりにくくなるので今回は避ける。
*修正しながら貼り付けたので間違えている部分があるかもしれない。

train_pos = ['i am happy']*10
train_neg = ['she is sad']*10
train_y = np.append(np.ones(10), np.zeros(10))

前処理は長くなるため省略。テキストが次のように変換されたと思いねぇ。

tensor_pos = []
tensor_neg = []
for i in range(10):
    tensor_pos.append([1,2,3])
    tensor_neg.append([4,5,6])
vocab = {'i':1, 'am':2, 'happy':3 , 'she':4 , 'is':5 , 'sad':6 }

DataLoader定義

本題ではないので適当に。
*import trax.data as data
data.Serial()を使えば、下記tl.Serial()と同じように前処理を積み重ねるように記述しreturns a generatorとできる。

import numpy as np

def data_generator(tensor_pos, tensor_neg, batch_size):
    n_to_take = batch_size // 2
    counter = 0
    stop = False
    while True:
        batch = []
        for i in range(n_to_take):
            if counter>=len(data):
                stop = True
                break
            batch.append(tensor_pos[i])
            batch.append(tensor_neg[i])
            counter += 1
        if stop==True:
            break
        inputs = np.array(batch)
        batch_y = []
        for i in range(len(batch)//2):
            batch_y += [1]
            batch_y += [0]
        targets = np.array(batch_y)
        example_weights = np.ones_like(targets)
        yield inputs, targets, example_weights

New Layer定義

必要に応じ

class Layer(object):
    def __init__(self):
        self.weights = None
    def forward(self, x):
        raise NotImplementedError
    def init_weights_and_state(self, input_signature, random_key):
        pass
    def init(self, input_signature, random_key):
        self.init_weights_and_state(input_signature, random_key)
        return self.weights
    def __call__(self, x):
        return self.forward(x)

class Dense(Layer): #標準の関数tl.Dense()があるが例として
    def __init__(self, n_units, init_stdev=0.1):
        self._n_units = n_units
        self._init_stdev = init_stdev
    def forward(self, x):
         #xとself.weight
        return dense
    def init_weights_and_state(self, input_signature, random_key):
         #shapeを見よ
        return self.weights
    #call return self.forward(x) ほか承継

model定義

入力→embed層→平均層→全結合層→出力

import trax
from trax import layers as tl

vocab_size = len(vocab)
embedding_dim = 256
output_dim = 2

model = tl.Serial(
            tl.Embedding(vocab_size,d_feature=embedding_dim), #githubより。以下同じ
            tl.Mean(axis=1), #2,3,256 2,256
            tl.Dense(output_dim),
            tl.LogSoftmax())

*reformerモデルの場合

model = trax.models.reformer.ReformerLM(
            vocab_size=vocab_size,
            n_layers=n_layers,
            mode=mode,
            attention_type=tl.SelfAttention)

中身は

Serial[
  ShiftRight(1)
  Embedding_train_512
  Dropout
  PositionalEncoding
  Dup_out2
  ReversibleSerial_in2_out2[
    ReversibleHalfResidualV2_in2_out2[
      Serial[
        LayerNorm
      ]
      SelfAttention
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidualV2_in2_out2[
      Serial[
        LayerNorm
        Dense_2048
        Dropout
        FastGelu
        Dense_512
        Dropout
      ]
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidualV2_in2_out2[
      Serial[
        LayerNorm
      ]
      SelfAttention
    ]
    ReversibleSwap_in2_out2
    ReversibleHalfResidualV2_in2_out2[
      Serial[
        LayerNorm
        Dense_2048
        Dropout
        FastGelu
        Dense_512
        Dropout
      ]
    ]
    ReversibleSwap_in2_out2
  ]
  Concatenate_in2
  LayerNorm
  Dropout
  Dense_train
  LogSoftmax
]

訓練

from trax.supervised import training

batch_size = 2

train_task = training.TrainTask(
                labeled_data = data_generator(tensor_pos, tensor_neg, batch_size),
                loss_layer = tl.CrossEntropyLoss(),
                optimizer = trax.optimizers.Adam(0.01), 
                n_steps_per_checkpoint=10,)
eval_task = training.EvalTask(
                labeled_data = data_generator(tensor_pos, tensor_neg, batch_size),#*省略の都合上、evalはtrainと同一データとしている
                metrics = [tl.CrossEntropyLoss(), tl.Accuracy()],)

training_loop = training.Loop(model, train_task, eval_task=eval_task, output_dir='./')
training_loop.run(n_steps=10)

Step 1: train CrossEntropyLoss | 0.41486400
Step 1: eval CrossEntropyLoss | 0.09423853
Step 1: eval Accuracy | 1.00000000
Step 10: train CrossEntropyLoss | 0.01481570
Step 10: eval CrossEntropyLoss | 0.00004220
Step 10: eval Accuracy | 1.00000000

予測例

tmp_inputs, _t, _w = next(data_generator(tensor_pos, tensor_neg, batch_size))
tmp_pred = training_loop.eval_model(tmp_inputs)

display(tmp_inputs, tmp_pred)
for tmp in tmp_pred:
    if tmp[0] < tmp[1]:
        print('pos')
    else:
        print('neg')

array(
  [[1, 2, 3],
  [4, 5, 6]])

DeviceArray(
  [[-9.4024553e+00, -8.2492828e-05],
  [-1.9073486e-06, -1.3045208e+01]],
  dtype=float32)

pos
neg

3
4
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
4