3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

Hugging Faceのtransformersを使用してSMILESに特化した言語モデルを作成した

Last updated at Posted at 2023-01-26

はじめに

Hugging Faceのライブラリーtransformersを用いてZINC・PubChem-10mのデータで事前学習を行い、化合物構造式の線形表記法のひとつであるSMILESに特化した言語モデル(T5, DeBERTa)を作成しました。これらの事前学習モデルを用いることで、化合物の物性や反応、タンパク質との相互作用など様々な予測を行えます。今回は例として逆反応予測(ある化学反応の生成物が与えられたとき、その反応に必要な反応物を予測する)を行いました。この記事では事前学習済みモデルを使い方と、どのようにして事前学習・finetuningを行ったかを紹介したいと思います。実際に収率予測や生成物予測をしたい、finetuningの方法を詳しく知りたいという方はこちらの記事をご覧ください。
コードの詳細については、githubを参照
今回使用した生データは次のリンクからダウンロードできます。(ZINCPubChem-10m, ORD)

目次

  1. 事前学習モデルの使い方
  2. データの前処理
  3. tokenizerの学習
  4. MLMによるモデルの事前学習
  5. optunaを用いたfinetuning時のハイパラ最適化
  6. finetuningによって化学反応予測モデルを作成
  7. まとめ
  8. 参考文献

事前学習モデルの使い方

Hugging Face HubにZINCとPubChem-10mで30epoch事前学習したモデルがアップロードされているため、それをロードすることで簡単に使うことができます。

ZINC-t5の利用
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model = AutoModelForSeq2SeqLM.from_pretrained('sagawa/ZINC-t5', from_flax=True)
tokenizer = AutoTokenizer.from_pretrained('sagawa/ZINC-t5')

input_ids = tokenizer('Nc1nc(N2CCN(C(=O)COc3ccc(Cl)cc3)CC2)c2c(-c3ccc(F)cc3)csc2n1', return_tensors='pt').input_ids
labels = tokenizer('CCN(C(C)C)C(C)C.Nc1nc(N2CCNCC2)c2c(-c3ccc(F)cc3)csc2n1.O=C(Cl)COc1ccc(Cl)cc1', return_tensors='pt').input_ids

loss = model(input_ids=input_ids, labels=labels).loss
loss.item()

データの前処理

言語モデルの事前学習に使うデータとしてZINCとPubChem-10mをダウンロードし、以下のコードによってデータのcanonical化(正規化)を行いました。

SMILESのcanonical化
from rdkit import Chem
def canonicalize(mol):
    mol = Chem.MolToSmiles(Chem.MolFromSmiles(mol),True)
    return mol

data['smiles'] = data['smiles'].apply(lambda x: canonicalize(x))

tokenizerの学習

mlmによる事前学習を行う前に、ZINCとPubChem-10mのデータそれぞれでtokenizerの学習を行いました。T5とDeBERTaでtokenizerの学習方法が違うのですが、実際のコードは以下のようになります。

tokenizerの学習
def create_normal_tokenizer(dataset, model_name):
    if type(dataset) == datasets.dataset_dict.DatasetDict:
        training_corpus = (
        dataset['train'][i : i + 1000]['smiles']
        for i in range(0, len(dataset), 1000)
        )
    else:
        training_corpus = (
        dataset[i : i + 1000]['smiles']
        for i in range(0, len(dataset), 1000)
        )

    if 'deberta' in model_name:
        # Train tokenizer
        old_tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer = old_tokenizer.train_new_from_iterator(training_corpus, 1000)
    elif 't5' in model_name:
        # https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/t5_tokenizer_model.py
        tokenizer = SentencePieceUnigramTokenizer(unk_token="<unk>", eos_token="</s>", pad_token="<pad>")
        tokenizer.train_from_iterator(training_corpus, 1000)
    
    return tokenizer

dataset = load_dataset('csv', data_files='./data/ZINC-canonicalized.csv')

# DeBERTa tokenizer
tokenizer = create_normal_tokenizer(dataset, 'microsoft/deberta-base')

# T5 tokenizer
tokenizer = create_normal_tokenizer(dataset, 't5')

SMILESでは大文字と小文字で意味が異なるため、T5 tokenizerで使ったSentencePieceUnigramTokenizerでは正規化の際に本来行われる小文字化を廃止しました。

私がSMILESに特化した言語モデルを作っているとのちょうど同時期に、T5ChemというSMILESに特化したmultimodalなモデルが公開されました。このモデルではtokenizerにcharacter-level tokenizerを採用しており、私の実験でも上述のtokenizerよりもcharacter-level tokenizerの方がよい性能を示すことが分かったため、以下のように実装しました。したこととしては入力データを半角スペース区切りにしてtokenizerの学習を行っただけです。

character-level tokenizerの学習
def create_character_level_tokenizer(dataset, model_name):
    df = dataset['train'].to_pandas()
    df['smiles'] = [' '.join(list(i)) for i in df['smiles']]
    dataset = datasets.Dataset.from_pandas(df)    
    
    tokenizer = create_normal_tokenizer(dataset, model_name)
    
    return tokenizer

# DeBERTa tokenizer
tokenizer = create_character_level_tokenizer(dataset, 'microsoft/deberta-base')

# T5 tokenizer
tokenizer = create_character_level_tokenizer(dataset, 't5')

mlmによるモデルの事前学習

ZINCとPubChem-10mのデータそれぞれについて90%をtrainデータ、10%をvalidationデータとしてモデルの事前学習を行いました。事前学習に使ったコードもT5とDeBERTaで異なり、T5はこちらのコードをDeBERTaはこちらのコードを参考にしました。

optunaを使ったfinetuning時のハイパラ最適化

モデルのfinetuningの結果に大きな影響を及ぼしそうなハイパーパラメータであるlearning_rateとweight_decayに関しては、optunaを使って最適化を行いました。
方法としては簡単で、trainer.train()として学習を行う代わりにtrainer.hyperparameter_search()として、引数に探索するパラメーターの範囲を指定することでできます。実際のコードは以下のようになります。

hyperparameter_tuning
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

args = Seq2SeqTrainingArguments(
    output_dir=cfg.output_dir,
    overwrite_output_dir=True, 
    evaluation_strategy=cfg.evaluation_strategy,
    learning_rate=cfg.lr,
    per_device_train_batch_size=cfg.batch_size,
    per_device_eval_batch_size=cfg.batch_size,
    weight_decay=cfg.weight_decay,
    num_train_epochs=cfg.epochs,
    predict_with_generate=True,
    save_total_limit=2,
    fp16=cfg.fp16,
    push_to_hub=False,
    disable_tqdm=True
)

Seq2SeqTrainer.hyperparameter_search = hyperparameter_search

trainer = Seq2SeqTrainer(
    model_init=get_model,
    args=args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

def my_hp_space(trial):
    return {
        "learning_rate": trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True),
        "weight_decay": trial.suggest_float("weight_decay", 0.001, 0.1, log=True),
    }

# start hyperparameter tuning
param = trainer.hyperparameter_search(
    hp_space=my_hp_space,
    n_trials=n_trials
)

しかし、このコードを普通に回すと初めの方は順調に進むのですが途中でout of memoryエラーが起こってしまい、batch_size=2と小さくしてもこの問題は解決されませんでした。どうやらハイパラ探索の間はGPUメモリーが解放されないために起こっているようです。hyperparameter_searchを次のように書き換え、メモリーの開放を行うようにすることで解決できました。

hyperparameter_tuning メモリーの開放
from transformers.trainer_utils import HPSearchBackend, default_hp_space
def run_hp_search_optuna(trainer, n_trials, direction, **kwargs):
    import optuna
    def _objective(trial, checkpoint_dir=None):
        checkpoint = None
        if checkpoint_dir:
            for subdir in os.listdir(checkpoint_dir):
                if subdir.startswith(PREFIX_CHECKPOINT_DIR):
                    checkpoint = os.path.join(checkpoint_dir, subdir)

        if not checkpoint:
            # free GPU memory
            del trainer.model
            gc.collect()
            torch.cuda.empty_cache()
        trainer.objective = None
        trainer.train(resume_from_checkpoint=checkpoint, trial=trial)
        # If there hasn't been any evaluation during the training loop.
        if getattr(trainer, "objective", None) is None:
            metrics = trainer.evaluate()
            trainer.objective = trainer.compute_objective(metrics)
        return trainer.objective

    timeout = kwargs.pop("timeout", None)
    n_jobs = kwargs.pop("n_jobs", 1)
    study = optuna.create_study(direction=direction, **kwargs)
    study.optimize(_objective, n_trials=n_trials, n_jobs=n_jobs)
    best_trial = study.best_trial
    return BestRun(str(best_trial.number), best_trial.value, best_trial.params)

def hyperparameter_search(trainer, n_trials, hp_space = None, compute_objective = None, direction = "minimize", hp_name = None, **kwargs):
    trainer.hp_search_backend = HPSearchBackend.OPTUNA
    trainer.hp_space = default_hp_space[HPSearchBackend.OPTUNA] if hp_space is None else hp_space
    trainer.hp_name = hp_name
    trainer.compute_objective = default_compute_objective if compute_objective is None else compute_objective
    best_run = run_hp_search_optuna(trainer, n_trials, direction, **kwargs)
    trainer.hp_search_backend = None
    return best_run

Seq2SeqTrainer.hyperparameter_search = hyperparameter_search

finetuningによって化学反応予測モデルを作成

finetuningにはordのデータを使用し、化学反応における生成物のSMILESを入力すると反応物のSMILESを出力するSeq2Seqのタスクで学習させました。

finetuning
tokenizer = AutoTokenizer.from_pretrained(CFG.model, return_tensors='pt')
# 化合物どうしをつなぐ'.'をvocabに追加
tokenizer.add_tokens('.')

if CFG.model == 't5':
    model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model, from_flax=True)
    # vocabを追加する場合はtoken_embeddingのサイズが変わるため、サイズを変更
    model.resize_token_embeddings(len(tokenizer))

elif CFG.model == 'deberta':
    model = EncoderDecoderModel.from_encoder_decoder_pretrained(CFG.model, 'roberta-large')
    model.encoder.resize_token_embeddings(len(tokenizer))
    model.decoder.resize_token_embeddings(len(tokenizer))
    config_encoder = model.config.encoder
    config_decoder = model.config.decoder
    config_decoder.is_decoder = True
    config_decoder.add_cross_attention = True
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset['train'].column_names,
    load_from_cache_file=False
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

args = Seq2SeqTrainingArguments(
    CFG.model,
    evaluation_strategy=CFG.evaluation_strategy,
    save_strategy=CFG.save_strategy,
    learning_rate=CFG.lr,
    per_device_train_batch_size=CFG.batch_size,
    per_device_eval_batch_size=CFG.batch_size,
    weight_decay=CFG.weight_decay,
    save_total_limit=CFG.save_total_limit,
    num_train_epochs=CFG.epochs,
    predict_with_generate=True,
    fp16=CFG.fp16,
    disable_tqdm=CFG.disable_tqdm,
    push_to_hub=False,
    load_best_model_at_end=True
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

T5はencoder-decoderモデルなのでそのままAutoModelForSeq2SeqLMを使えたのですが、DeBERTaはencoderモデルでかつdecoderがHugging Faceで公開されていないためAutoModelForSeq2SeqLMを使えませんでした。そのため、encoderは事前学習済みのDeBERTaのものを使い、decoderはRoBERTaのものを使いました。
結果は次のようにZINC-t5の方が事前学習なしのt5-baseよりもlossが低くなっており、事前学習を行うことでより高い精度で反応物を予測できていることがわかります。
original-t5-finetuning-loss.png
zinc-t5-finetuning-loss.png

まとめ

今回はT5,DeBERTaを有名な化合物データセットであるZINC、PubChemで事前学習させ、化合物構造式の線形表記法のひとつであるSMILESに特化した言語モデルを作成しました。そして、最後に化学反応における生成物を入力すると反応物を出力するというSeq2Seqのタスクに適用すると、事前学習なしの場合に比べてよい性能を示すことも確認しました。また、これらのモデルはHugging Faceのtransformerをベースにしているため、trainerのAPIを用いれば数行のコードで様々なタスクに適用できるようになっています。SMILESを扱う機会があれば、ぜひ利用してみてください。

(この記事は研究室インターンで取り組みました:https://kojima-r.github.io/kojima/)

参考文献

3
3
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
3

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?