Predicting chemical reactions with Hugging Face transformers

Posted at 2023-01-26

Introduction

In this article, we use the pretrained models built in the previous article to predict the yields and products of chemical reactions. The dataset used is available here.
For the details of the code, please see GitHub.
If you would like to try it right away without setting up an environment, I have built demos on Hugging Face Spaces, so please have a look there (product prediction, yield prediction).

Table of contents

  1. Using the fine-tuned models
  2. Data preprocessing
  3. Product prediction
  4. Yield prediction
  5. Summary

Using the fine-tuned models

First, product prediction.
A model fine-tuned on the ORD (Open Reaction Database) has been uploaded to the Hugging Face Hub, so you can use it simply by loading it.

Product prediction
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The ORD-finetuned product-prediction model (set this to its Hub name or local path)
class CFG:
    model_name_or_path = '<finetuned-model-name-or-path>'

model = AutoModelForSeq2SeqLM.from_pretrained(CFG.model_name_or_path).to(device)
tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path, return_tensors='pt')

min_length = 41
input = "REACTANT:CNc1nc(SC)ncc1CO.O.O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].[Na+]CATALYST: REAGENT: SOLVENT:CC(=O)O"
inp = tokenizer(input, return_tensors='pt').to(device)
output = model.generate(**inp, min_length=min_length, max_length=min_length+50, num_beams=5, num_return_sequences=1)
output = [tokenizer.decode(i, skip_special_tokens=True).replace('. ', '.').rstrip('.') for i in output][0]
print(output) # CC(C)CN=C1C(c2ccccc2)=C(c2ccccc2)C(c2ccccc2)=C1c1ccccc1

入力は"REACTANT:{反応物のsmiles}CATALYST:{触媒のsmiles}REAGENT:{溶媒のsmiles}"を想定しています。もし触媒や溶媒の情報がない場合は半角スペースを代わりに入れます。
また、入力は正規化されていることを想定しています。RDKitを用いた正規化の方法は前回の記事を参照してください。
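For reference, here is a minimal canonicalization sketch with RDKit; it is the same approach as the canonicalize() helper that appears in the preprocessing code later in this article.

from rdkit import Chem

def canonicalize(smi):
    # Round-trip through an RDKit Mol object to obtain the canonical SMILES
    return Chem.MolToSmiles(Chem.MolFromSmiles(smi), True)

print(canonicalize('OCc1cnc(SC)nc1NC'))  # the first reactant above, written in a non-canonical order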

Next is yield prediction.
The yield prediction model has to be downloaded as follows.

Yield prediction
import os          # needed for the directory handling below
import subprocess  # needed for the wget download commands below
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, AutoConfig, AutoModel, T5EncoderModel, T5ForConditionalGeneration
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from tqdm import tqdm
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class CFG:
    model = 't5'
    fc_dropout = 0.1
    max_len = 512
    download_pretrained_model = True

if CFG.download_pretrained_model:
    os.makedirs('tokenizer', exist_ok=True)
    subprocess.run('wget https://huggingface.co/spaces/sagawa/predictyield-t5/resolve/main/ZINC-t5_best.pth', shell=True)
    subprocess.run('wget https://huggingface.co/spaces/sagawa/predictyield-t5/resolve/main/config.pth', shell=True)
    subprocess.run('wget https://huggingface.co/spaces/sagawa/predictyield-t5/raw/main/special_tokens_map.json -P ./tokenizer', shell=True)
    subprocess.run('wget https://huggingface.co/spaces/sagawa/predictyield-t5/raw/main/tokenizer.json -P ./tokenizer', shell=True)
    subprocess.run('wget https://huggingface.co/spaces/sagawa/predictyield-t5/raw/main/tokenizer_config.json -P ./tokenizer', shell=True)
    CFG.model_name_or_path = '.'

CFG.tokenizer = AutoTokenizer.from_pretrained(CFG.model_name_or_path+'/tokenizer', return_tensors='pt')

class RegressionModel(nn.Module):
    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.pretrained_model_name_or_path, output_hidden_states=True)
        else:
            self.config = torch.load(config_path)
        if pretrained:
            if 't5' in cfg.model:
                self.model = T5ForConditionalGeneration.from_pretrained(CFG.pretrained_model_name_or_path)
            else:
                self.model = AutoModel.from_pretrained(CFG.pretrained_model_name_or_path)
        else:
            if 't5' in cfg.model:
                self.model = T5ForConditionalGeneration.from_pretrained('sagawa/ZINC-t5')
            else:
                self.model = AutoModel.from_config(self.config)
        self.model.resize_token_embeddings(len(cfg.tokenizer))

        self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size//2)
        self.fc2 = nn.Linear(self.config.hidden_size, self.config.hidden_size//2)
        self.fc3 = nn.Linear(self.config.hidden_size//2*2, self.config.hidden_size)
        self.fc4 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.fc5 = nn.Linear(self.config.hidden_size, 1)

        self._init_weights(self.fc1)
        self._init_weights(self.fc2)
        self._init_weights(self.fc3)
        self._init_weights(self.fc4)
        self._init_weights(self.fc5)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        
    def forward(self, inputs):
        # Encode the tokenized reaction string with the T5 encoder
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]
        # Run the decoder for a single step, feeding only the decoder start token
        outputs = self.model.decoder(input_ids=torch.full((inputs['input_ids'].size(0), 1),
                                            self.config.decoder_start_token_id,
                                            dtype=torch.long,
                                            device=device), encoder_hidden_states=encoder_hidden_states)
        last_hidden_states = outputs[0]
        # Combine the decoder's start-token state with the encoder's first-token state,
        # then regress the yield with the fully connected head
        output1 = self.fc1(last_hidden_states.view(-1, self.config.hidden_size))
        output2 = self.fc2(encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size))
        output = self.fc3(torch.hstack((output1, output2)))
        output = self.fc4(output)
        output = self.fc5(output)
        return output

def prepare_input(cfg, text):
    inputs = cfg.tokenizer(text, add_special_tokens=True, max_length=CFG.max_len, padding='max_length', return_offsets_mapping=False, truncation=True, return_attention_mask=True)
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    
    return inputs

class TestDataset(Dataset):
    def __init__(self, cfg, df):
        self.cfg = cfg
        self.inputs = df['input'].values
        
    def __len__(self):
        return len(self.inputs)
    
    def __getitem__(self, item):
        inputs = prepare_input(self.cfg, self.inputs[item])
        
        return inputs
    
def inference_fn(test_loader, model, device):
    preds = []
    model.eval()
    model.to(device)
    tk0 = tqdm(test_loader, total=len(test_loader))
    for inputs in tk0:
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        preds.append(y_preds.to('cpu').numpy())
    predictions = np.concatenate(preds)
    return predictions

model = RegressionModel(CFG, config_path=CFG.model_name_or_path + '/config.pth', pretrained=False)
state = torch.load(CFG.model_name_or_path + '/ZINC-t5_best.pth', map_location=torch.device('cpu'))
model.load_state_dict(state)

input = "REACTANT:CC(C)n1ncnc1-c1cn2c(n1)-c1cnc(O)cc1OCC2.CCN(C(C)C)C(C)C.Cl.NC(=O)[C@@H]1C[C@H](F)CN1REAGENT: PRODUCT:O=C(NNC(=O)C(F)(F)F)C(F)(F)F)"
test_ds = pd.DataFrame.from_dict({'input': input}, orient='index').T
test_dataset = TestDataset(CFG, test_ds)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=1, pin_memory=True, drop_last=False)

prediction = inference_fn(test_loader, model, device)
print(prediction[0][0]*100) # 19.27478015422821  (the model outputs yield in the 0-1 range, so multiply by 100 for %)

Yield prediction uses the encoder part of the pretrained model (the decoder is only run for a single start-token step, and fully connected layers on top regress the yield).
The expected input format is "REACTANT:{reactant SMILES}REAGENT:{solvent and catalyst SMILES}PRODUCT:{product SMILES}".

Next, I will explain in detail how the fine-tuning was done.

Data preprocessing

We downloaded the ORD dataset, split each entry into categories as shown below, and then removed the atom mapping and canonicalized the SMILES.

Data preprocessing
import os
import random
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from rdkit import RDLogger, Chem
import math
RDLogger.DisableLog('rdApp.*')   

class CFG():
    data='all_ord_reaction_uniq_with_attr_v1.tsv'
    seed = 42


def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
seed_everything(seed=CFG.seed)  
    
df = pd.read_csv(CFG.data, sep='\t', names=['id', 'input', 'product', 'condition'])


def data_split(row):
    dic = {'CATALYST': [], 'REACTANT': [], 'REAGENT': [], 'SOLVENT': [], 'INTERNAL_STANDARD': [], 'NoData': []}
    inp_cat = ['CATALYST', 'REACTANT', 'REAGENT', 'SOLVENT', 'INTERNAL_STANDARD', 'NoData']
    inp = row['input']
    if type(inp) == str:
        for item in inp.split('.'):
            for cat in inp_cat:
                if cat in item:
                    dic[cat].append(item[item.find(':')+1:])
                    break
    for k, v in dic.items():
        dic[k] = '.'.join(dic[k])

    pro = row['product']
    if type(pro) == str:
        pro = pro.replace('.PRODUCT', 'PRODUCT')
        pro_lis = []
        for item in pro.split('PRODUCT:'):
            if item != '':
                pro_lis.append(item)
        dic['PRODUCT'] = '.'.join(pro_lis)
    else:
        dic['PRODUCT'] = None
    
    con = row['condition']
    if type(con) == str:
        if 'YIELD' in con and 'TEMP' in con:
            pos = con.find('.T')
            for item, cat in zip([con[:pos], con[pos:]], ['YIELD', 'TEMP']):   
                dic[cat] = float(item[item.find(':')+1:])
        elif 'YIELD' in con:
            dic['YIELD'] = float(con[con.find(':')+1:])
            dic['TEMP'] = None
        elif 'TEMP' in con:
            dic['YIELD'] = None
            dic['TEMP'] = float(con[con.find(':')+1:])
        else:
            print(con)
    else:
        for cat in ['YIELD', 'TEMP']:
            dic[cat] = None
    return list(dic.values())

dic = {'CATALYST': [], 'REACTANT': [], 'REAGENT': [], 'SOLVENT': [], 'INTERNAL_STANDARD': [], 'NoData': [], 'PRODUCT': [],'YIELD': [], 'TEMP': []}
cat = ['CATALYST', 'REACTANT', 'REAGENT', 'SOLVENT', 'INTERNAL_STANDARD', 'NoData','PRODUCT', 'YIELD', 'TEMP']
for idx, row in df.iterrows():
    lst = data_split(row)
    for i in range(len(lst)):
        dic[cat[i]].append(lst[i]) 
cleaned_df = pd.DataFrame(dic)

def remove_atom_mapping(smi):
    mol = Chem.MolFromSmiles(smi)
    for a in mol.GetAtoms():
        a.SetAtomMapNum(0)
    smi = Chem.MolToSmiles(mol)
    return canonicalize(smi)

def canonicalize(smi):
    smi = Chem.MolToSmiles(Chem.MolFromSmiles(smi),True)
    return smi

cleaned_df['CATALYST'] = cleaned_df['CATALYST'].apply(lambda x: remove_atom_mapping(x) if type(x) == str else None)
cleaned_df['REACTANT'] = cleaned_df['REACTANT'].apply(lambda x: remove_atom_mapping(x) if type(x) == str else None)
cleaned_df['REAGENT'] = cleaned_df['REAGENT'].apply(lambda x: remove_atom_mapping(x) if type(x) == str else None)
cleaned_df['SOLVENT'] = cleaned_df['SOLVENT'].apply(lambda x: remove_atom_mapping(x) if type(x) == str else None)
cleaned_df['INTERNAL_STANDARD'] = cleaned_df['INTERNAL_STANDARD'].apply(lambda x: remove_atom_mapping(x) if type(x) == str else None)
cleaned_df['NoData'] = cleaned_df['NoData'].apply(lambda x: remove_atom_mapping(x) if type(x) == str else None)
cleaned_df['PRODUCT'] = cleaned_df['PRODUCT'].apply(lambda x: remove_atom_mapping(x) if type(x) == str else None)

cleaned_df.to_csv('all_ord_reaction_uniq_with_attr_v3.csv', index=False)

Product prediction

We train a product-prediction model for chemical reactions using the ORD data.

Product prediction (training)
import os
import gc
import random
import itertools
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import torch
import tokenizers
import transformers
from transformers import AutoTokenizer, EncoderDecoderModel, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
import datasets
from datasets import load_dataset, load_metric, Dataset, DatasetDict
import sentencepiece
import argparse
from sklearn.model_selection import train_test_split
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, required=False)
    parser.add_argument("--dataset_name", type=str, required=False)
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--debug", action='store_true', default=False, required=False)
    parser.add_argument("--epochs", type=int, default=3, required=False)
    parser.add_argument("--lr", type=float, default=2e-5, required=False)
    parser.add_argument("--batch_size", type=int, default=16, required=False)
    parser.add_argument("--input_max_len", type=int, default=128, required=False)
    parser.add_argument("--target_max_len", type=int, default=128, required=False)
    parser.add_argument("--weight_decay", type=float, default=0.01, required=False)
    parser.add_argument("--evaluation_strategy", type=str, default="epoch", required=False)
    parser.add_argument("--save_strategy", type=str, default="epoch", required=False)
    parser.add_argument("--logging_strategy", type=str, default="epoch", required=False)
    parser.add_argument("--save_total_limit", type=int, default=2, required=False)
    parser.add_argument("--fp16", action='store_true', default=False, required=False)
    parser.add_argument("--disable_tqdm", action="store_true", default=False, required=False)
    parser.add_argument("--multitask", action="store_true", default=False, required=False)
    parser.add_argument("--seed", type=int, default=42, required=False)

    return parser.parse_args()
    
CFG = parse_args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(seed=CFG.seed)  
    

df = pd.read_csv(CFG.data_path)
df = df[~df['PRODUCT'].isna()]
for col in ['CATALYST', 'REACTANT', 'REAGENT', 'SOLVENT', 'INTERNAL_STANDARD', 'NoData','PRODUCT', 'YIELD', 'TEMP']:
    df[col] = df[col].fillna(' ')
df['TEMP'] = df['TEMP'].apply(lambda x: str(x))


df = df[df['REACTANT'] != ' ']
df = df[['REACTANT', 'PRODUCT', 'CATALYST', 'REAGENT', 'SOLVENT']].drop_duplicates().reset_index(drop=True)
df = df.iloc[df[['REACTANT', 'CATALYST', 'REAGENT', 'SOLVENT']].drop_duplicates().index].reset_index(drop=True)


def clean(row):
    row = row.replace('. ', '').replace(' .', '').replace('  ', ' ')
    return row
df['REAGENT'] = df['CATALYST'] + '.' + df['REAGENT'] + '.' + df['SOLVENT']
df['REAGENT'] = df['REAGENT'].apply(lambda x: clean(x))
from rdkit import Chem
def canonicalize(mol):
    mol = Chem.MolToSmiles(Chem.MolFromSmiles(mol),True)
    return mol
df['REAGENT'] = df['REAGENT'].apply(lambda x: canonicalize(x) if x != ' ' else ' ')

df['input'] = 'REACTANT:' + df['REACTANT'] + 'REAGENT:' + df['REAGENT']

lens = df['input'].apply(lambda x: len(x))
df = df[lens <= 512]

train, test = train_test_split(df, test_size=int(len(df)*0.1))
train, valid = train_test_split(train, test_size=int(len(df)*0.1))


if CFG.debug:
    train = train[:int(len(train)/400)].reset_index(drop=True)
    valid = valid[:int(len(valid)/40)].reset_index(drop=True)
    
    
train[['input', 'PRODUCT']].to_csv('multi-input-train.csv', index=False)
valid[['input', 'PRODUCT']].to_csv('multi-input-valid.csv', index=False)
test[['input', 'PRODUCT']].to_csv('multi-input-test.csv', index=False)

nodata = pd.read_csv('/data2/sagawa/transformer-chemical-reaction-prediciton/compound-classification/reconstructed.csv')
nodata = nodata[~nodata['REACTANT'].isna()]
for col in ['REAGENT']:
    nodata[col] = nodata[col].fillna(' ')
nodata['input'] = 'REACTANT:' + nodata['REACTANT'] + 'REAGENT:' + nodata['REAGENT']
train = pd.concat([train[['input', 'PRODUCT']], nodata[['input', 'PRODUCT']]]).reset_index(drop=True)

dataset = DatasetDict({'train': Dataset.from_pandas(train[['input', 'PRODUCT']]), 'validation': Dataset.from_pandas(valid[['input', 'PRODUCT']])})


def preprocess_function(examples):
    inputs = examples['input']
    targets = examples['PRODUCT']
    model_inputs = tokenizer(inputs, max_length=CFG.input_max_len, truncation=True)
    labels = tokenizer(targets, max_length=CFG.target_max_len, truncation=True)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

def compute_metrics(eval_preds):
    metric = load_metric('sacrebleu')
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {'bleu': result['score']}


#load tokenizer
try: # load pretrained tokenizer from local directory
    tokenizer = AutoTokenizer.from_pretrained(os.path.abspath(CFG.pretrained_model_name_or_path), return_tensors='pt')
except: # load pretrained tokenizer from huggingface model hub
    tokenizer = AutoTokenizer.from_pretrained(CFG.pretrained_model_name_or_path, return_tensors='pt')
tokenizer.add_tokens('.')
tokenizer.add_special_tokens({'additional_special_tokens': tokenizer.additional_special_tokens + ['REACTANT:', 'REAGENT:']})


#load model
if CFG.model == 't5':
    try: # load pretrained model from local directory
        model = AutoModelForSeq2SeqLM.from_pretrained(os.path.abspath(CFG.pretrained_model_name_or_path), from_flax=True)
    except: # load pretrained model from huggingface model hub
        model = AutoModelForSeq2SeqLM.from_pretrained(CFG.pretrained_model_name_or_path, from_flax=True)
    model.resize_token_embeddings(len(tokenizer))
elif CFG.model == 'deberta':
    try: # load pretrained model from local directory
        model = EncoderDecoderModel.from_encoder_decoder_pretrained(os.path.abspath(CFG.pretrained_model_name_or_path), 'roberta-large')
    except: # load pretrained model from huggingface model hub
        model = EncoderDecoderModel.from_encoder_decoder_pretrained(CFG.pretrained_model_name_or_path, 'roberta-large')
    model.encoder.resize_token_embeddings(len(tokenizer))
    model.decoder.resize_token_embeddings(len(tokenizer))
    config_encoder = model.config.encoder
    config_decoder = model.config.decoder
    config_decoder.is_decoder = True
    config_decoder.add_cross_attention = True
    model.config.decoder_start_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset['train'].column_names,
    load_from_cache_file=False
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

args = Seq2SeqTrainingArguments(
    CFG.model,
    evaluation_strategy=CFG.evaluation_strategy,
    save_strategy=CFG.save_strategy,
    learning_rate=CFG.lr,
    per_device_train_batch_size=CFG.batch_size,
    per_device_eval_batch_size=CFG.batch_size,
    weight_decay=CFG.weight_decay,
    save_total_limit=CFG.save_total_limit,
    num_train_epochs=CFG.epochs,
    predict_with_generate=True,
    fp16=CFG.fp16,
    disable_tqdm=CFG.disable_tqdm,
    push_to_hub=False,
    load_best_model_at_end=True
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

trainer.train()

trainer.save_model('./best_model')

DeBERTa is an encoder-only model, and no matching decoder is published on Hugging Face. We therefore handled this by attaching a RoBERTa decoder to our pretrained DeBERTa encoder.
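For completeness, here is a minimal inference sketch for such an encoder-decoder model, assuming `model` and `tokenizer` are the objects built in the training script above; the generate() call mirrors the T5 example at the beginning of the article, and the example input string is hypothetical.

# Hypothetical inference with the fine-tuned encoder-decoder model
# (training inputs were built as 'REACTANT:...REAGENT:...', see above)
inp = tokenizer('REACTANT:CNc1nc(SC)ncc1CO.O.O=[Cr](=O)([O-])O[Cr](=O)(=O)[O-].[Na+]REAGENT:CC(=O)O', return_tensors='pt')
output = model.generate(**inp, max_length=128, num_beams=5, num_return_sequences=1)
print(tokenizer.decode(output[0], skip_special_tokens=True))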

Yield prediction

We train a yield-prediction model for chemical reactions using the ORD data.

Yield prediction (training)
import os
import gc
import random
import itertools
import warnings
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
import tokenizers
import transformers
from transformers import AutoTokenizer, AutoConfig, AutoModel, T5EncoderModel, get_linear_schedule_with_warmup, T5ForConditionalGeneration
import datasets
from datasets import load_dataset, load_metric
import sentencepiece
import argparse
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
import torch.nn as nn
from torch.optim import AdamW
import pickle
import time
import math
from sklearn.preprocessing import MinMaxScaler
from datasets.utils.logging import disable_progress_bar
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
disable_progress_bar()
os.environ['TOKENIZERS_PARALLELISM']='false'

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_path", type=str, required=False)
    parser.add_argument("--model", type=str, default='t5', required=False)
    parser.add_argument("--pretrained_model_name_or_path", type=str, required=True)
    parser.add_argument("--model_name_or_path", type=str, required=False)
    parser.add_argument("--debug", action='store_true', default=False, required=False)
    parser.add_argument("--epochs", type=int, default=5, required=False)
    parser.add_argument("--patience", type=int, default=10, required=False)
    parser.add_argument("--lr", type=float, default=5e-4, required=False)
    parser.add_argument("--batch_size", type=int, default=5, required=False)
    parser.add_argument("--max_len", type=int, default=512, required=False)
    parser.add_argument("--num_workers", type=int, default=1, required=False)
    parser.add_argument("--fc_dropout", type=float, default=0.0, required=False)
    parser.add_argument("--eps", type=float, default=1e-6, required=False)
    parser.add_argument("--max_grad_norm", type=int, default=1000, required=False)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=1, required=False)
    parser.add_argument("--num_warmup_steps", type=int, default=0, required=False)
    parser.add_argument("--batch_scheduler", action='store_true', default=False, required=False)
    parser.add_argument("--print_freq", type=int, default=100, required=False)
    parser.add_argument("--use_apex", action='store_true', default=False, required=False)
    parser.add_argument("--output_dir", type=str, default='./', required=False)
    parser.add_argument("--weight_decay", type=float, default=0.01, required=False)
    parser.add_argument("--seed", type=int, default=42, required=False)

    return parser.parse_args()

CFG = parse_args()


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

OUTPUT_DIR = CFG.output_dir
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
seed_everything(seed=CFG.seed)  
    

df = pd.read_csv(CFG.data_path).drop_duplicates().reset_index(drop=True)
df = df[~df['YIELD'].isna()].reset_index(drop=True)
df = df[~(df['YIELD']>100)].reset_index(drop=True)
df['YIELD'] = df['YIELD']/100
df = df[~(df['REACTANT'].isna() | df['PRODUCT'].isna())]
for col in ['CATALYST', 'REACTANT', 'REAGENT', 'SOLVENT', 'INTERNAL_STANDARD', 'NoData','PRODUCT']:
    df[col] = df[col].fillna(' ')
    

def clean(row):
    row = row.replace('. ', '').replace(' .', '').replace('  ', ' ')
    return row
df['REAGENT'] = df['CATALYST'] + '.' + df['REAGENT']
df['REAGENT'] = df['REAGENT'].apply(lambda x: clean(x))

from rdkit import Chem
def canonicalize(mol):
    mol = Chem.MolToSmiles(Chem.MolFromSmiles(mol),True)
    return mol

df['REACTANT'] = df['REACTANT'].apply(lambda x: canonicalize(x) if x != ' ' else ' ')
df['REAGENT'] = df['REAGENT'].apply(lambda x: canonicalize(x) if x != ' ' else ' ')
df['PRODUCT'] = df['PRODUCT'].apply(lambda x: canonicalize(x) if x != ' ' else ' ')


df['input'] = 'REACTANT:' + df['REACTANT']  + 'REAGENT:' + df['REAGENT'] + 'PRODUCT:' + df['PRODUCT']
df = df[['input', 'YIELD']].drop_duplicates().reset_index(drop=True)


lens = df['input'].apply(lambda x: len(x))
df = df[lens <= 512].reset_index(drop=True)

train_ds, test_ds = train_test_split(df, test_size=int(len(df)*0.1))
train_ds, valid_ds = train_test_split(train_ds, test_size=int(len(df)*0.1))
train_ds.to_csv('regression-input-train.csv', index=False)
valid_ds.to_csv('regression-input-valid.csv', index=False)
test_ds.to_csv('regression-input-test.csv', index=False)

if CFG.debug:
    train_ds = train_ds[:int(len(train_ds)/4)].reset_index(drop=True)
    valid_ds = valid_ds[:int(len(valid_ds)/4)].reset_index(drop=True)
        
    
def get_logger(filename=OUTPUT_DIR+'train'):
    from logging import getLogger, INFO, StreamHandler, FileHandler, Formatter
    logger = getLogger(__name__)
    logger.setLevel(INFO)
    handler1 = StreamHandler()
    handler1.setFormatter(Formatter("%(message)s"))
    handler2 = FileHandler(filename=f"{filename}.log")
    handler2.setFormatter(Formatter("%(message)s"))
    logger.addHandler(handler1)
    logger.addHandler(handler2)
    return logger

LOGGER = get_logger()

#load tokenizer
try: # load pretrained tokenizer from local directory
    tokenizer = AutoTokenizer.from_pretrained(os.path.abspath(CFG.pretrained_model_name_or_path), return_tensors='pt')
except: # load pretrained tokenizer from huggingface model hub
    tokenizer = AutoTokenizer.from_pretrained(CFG.pretrained_model_name_or_path, return_tensors='pt')
tokenizer.add_tokens(['.', 'P', '>', '<','Pd'])

tokenizer.add_special_tokens({'additional_special_tokens': tokenizer.additional_special_tokens + ['REACTANT:', 'PRODUCT:', 'REAGENT:']})
tokenizer.save_pretrained(OUTPUT_DIR+'tokenizer/')
CFG.tokenizer = tokenizer

def prepare_input(cfg, text):
    inputs = cfg.tokenizer(text, add_special_tokens=True, max_length=CFG.max_len, padding='max_length', return_offsets_mapping=False, truncation=True, return_attention_mask=True)
    for k, v in inputs.items():
        inputs[k] = torch.tensor(v, dtype=torch.long)
    
    return inputs


class TrainDataset(Dataset):
    def __init__(self, cfg, df):
        self.cfg = cfg
        self.inputs = df['input'].values
        self.labels = df['YIELD'].values
        
    def __len__(self):
        return len(self.labels)
    
    def __getitem__(self, item):
        inputs = prepare_input(self.cfg, self.inputs[item])
        label = torch.tensor(self.labels[item], dtype=torch.float)
        
        return inputs, label
    
       
class RegressionModel(nn.Module):
    def __init__(self, cfg, config_path=None, pretrained=False):
        super().__init__()
        self.cfg = cfg
        if config_path is None:
            self.config = AutoConfig.from_pretrained(cfg.pretrained_model_name_or_path, output_hidden_states=True)
        else:
            self.config = torch.load(config_path)
        if pretrained:
            if 't5' in cfg.model:
                self.model = T5ForConditionalGeneration.from_pretrained(CFG.pretrained_model_name_or_path)
            else:
                self.model = AutoModel.from_pretrained(CFG.pretrained_model_name_or_path)
        else:
            if 't5' in cfg.model:
                self.model = T5ForConditionalGeneration.from_pretrained('sagawa/ZINC-t5')
            else:
                self.model = AutoModel.from_config(self.config)
        self.model.resize_token_embeddings(len(cfg.tokenizer))
        self.fc_dropout1 = nn.Dropout(cfg.fc_dropout)
        self.fc1 = nn.Linear(self.config.hidden_size, self.config.hidden_size//2)
        self.fc_dropout2 = nn.Dropout(cfg.fc_dropout)
        
        self.fc2 = nn.Linear(self.config.hidden_size, self.config.hidden_size//2)
        self.fc3 = nn.Linear(self.config.hidden_size//2*2, self.config.hidden_size)
        self.fc4 = nn.Linear(self.config.hidden_size, self.config.hidden_size)
        self.fc5 = nn.Linear(self.config.hidden_size, 1)

        self._init_weights(self.fc1)
        self._init_weights(self.fc2)
        self._init_weights(self.fc3)
        self._init_weights(self.fc4)
        self._init_weights(self.fc5)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.01)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        
    def forward(self, inputs):
        encoder_outputs = self.model.encoder(**inputs)
        encoder_hidden_states = encoder_outputs[0]
        outputs = self.model.decoder(input_ids=torch.full((inputs['input_ids'].size(0),1),
                                            self.config.decoder_start_token_id,
                                            dtype=torch.long,
                                            device=device), encoder_hidden_states=encoder_hidden_states)
        last_hidden_states = outputs[0]
        output1 = self.fc1(self.fc_dropout1(last_hidden_states).view(-1, self.config.hidden_size))
        output2 = self.fc2(encoder_hidden_states[:, 0, :].view(-1, self.config.hidden_size))
        output = self.fc3(self.fc_dropout2(torch.hstack((output1, output2))))
        output = self.fc4(output)
        output = self.fc5(output)
        return output
    
    
class AverageMeter(object):
    def __init__(self):
        self.reset()
        
    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0
        
    def update(self, val, n=1):
        self.val = val
        self.sum += val*n
        self.count += n
        self.avg = self.sum/self.count
        

def asMinutes(s):
    m = math.floor(s/60)
    s -= m*60
    return '%dm %ds' % (m, s)


def timeSince(since, percent):
    now = time.time()
    s = now - since
    es = s/(percent)
    rs = es - s
    return '%s (remain %s)' % (asMinutes(s), asMinutes(rs))


def train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device):
    model.train()
    scaler = torch.cuda.amp.GradScaler(enabled=CFG.use_apex)
    losses = AverageMeter()
    start = end = time.time()
    global_step = 0
    for step, (inputs, labels) in enumerate(train_loader):
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        labels = labels.to(device)
        batch_size = labels.size(0)
        with torch.cuda.amp.autocast(enabled=CFG.use_apex):
            y_preds = model(inputs)
        loss = criterion(y_preds.view(-1, 1), labels.view(-1, 1))
        if CFG.gradient_accumulation_steps > 1:
            loss = loss/CFG.gradient_accumulation_steps
        losses.update(loss.item(), batch_size)
        scaler.scale(loss).backward()
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), CFG.max_grad_norm)
        if (step + 1) % CFG.gradient_accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()
            global_step += 1
            if CFG.batch_scheduler:
                scheduler.step()
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(train_loader)-1):
            print('Epoch: [{0}][{1}/{2}] '
                  'Elapsed {remain:s} '
                  'Loss: {loss.val:.4f}({loss.avg:.4f}) '
                  'Grad: {grad_norm:.4f}  '
                  'LR: {lr:.8f}  '
                  .format(epoch+1, step, len(train_loader), 
                          remain=timeSince(start, float(step+1)/len(train_loader)),
                          loss=losses,
                          grad_norm=grad_norm,
                          lr=scheduler.get_lr()[0]), flush=True)
    return losses.avg


def valid_fn(valid_loader, model, criterion, device):
    losses = AverageMeter()
    model.eval()
    start = end = time.time()
    label_list = []
    pred_list = []
    for step, (inputs, labels) in enumerate(valid_loader):
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        label_list += labels.tolist()
        pred_list += y_preds.tolist()
        end = time.time()
        if step % CFG.print_freq == 0 or step == (len(valid_loader)-1):
            print('EVAL: [{0}/{1}] '
                  'Elapsed {remain:s} '
                  'RMSE Loss: {loss:.4f} '
                  'r2 score: {r2_score:.4f} '
                  .format(step, len(valid_loader),
                          loss=mean_squared_error(label_list, pred_list, squared=False),
                          remain=timeSince(start, float(step+1)/len(valid_loader)),
                          r2_score=r2_score(label_list, pred_list)))
    return mean_squared_error(label_list, pred_list, squared=False), r2_score(label_list, pred_list)  # return RMSE, as logged by the training loop
    
def inference_fn(test_loader, model, device):
    preds = []
    model.eval()
    model.to(device)
    tk0 = tqdm(test_loader, total=len(test_loader))
    for inputs in tk0:
        for k, v in inputs.items():
            inputs[k] = v.to(device)
        with torch.no_grad():
            y_preds = model(inputs)
        preds.append(y_preds.to('cpu').numpy())
    predictions = np.concatenate(preds)
    return predictions


def train_loop(train_ds, valid_ds):
    
    train_dataset = TrainDataset(CFG, train_ds)
    valid_dataset = TrainDataset(CFG, valid_ds)
    valid_labels = valid_ds['YIELD'].values
    
    train_loader = DataLoader(train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers, pin_memory=True, drop_last=True)
    valid_loader = DataLoader(valid_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers, pin_memory=True, drop_last=False)
    
    model = RegressionModel(CFG, config_path=None, pretrained=True)
    torch.save(model.config, OUTPUT_DIR+'config.pth')
    model.to(device)
    
    def get_optimizer_params(model, encoder_lr, decoder_lr, weight_decay=0.0):
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_parameters = [
            {'params': [p for n, p in model.model.named_parameters() if not any(nd in n for nd in no_decay)], 'lr': encoder_lr, 'weight_decay': weight_decay},
            {'params': [p for n, p in model.model.named_parameters() if any(nd in n for nd in no_decay)], 'lr': encoder_lr, 'weight_decay': 0.0},
            {'params': [p for n, p in model.named_parameters() if 'model' not in n], 'lr': decoder_lr, 'weight_decay': 0.0}
        ]
        return optimizer_parameters
    
    optimizer_parameters = get_optimizer_params(model, encoder_lr=CFG.lr, decoder_lr=CFG.lr, weight_decay=CFG.weight_decay)
    optimizer = AdamW(optimizer_parameters, lr=CFG.lr, eps=CFG.eps, betas=(0.9, 0.999))
    
    num_train_steps = int(len(train_ds)/CFG.batch_size*CFG.epochs)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=CFG.num_warmup_steps, num_training_steps=num_train_steps)
    
    criterion = nn.MSELoss(reduction='mean')
    best_loss = float('inf')
    es_count = 0
    
    for epoch in range(CFG.epochs):
        start_time = time.time()

        avg_loss = train_fn(train_loader, model, criterion, optimizer, epoch, scheduler, device)
        val_loss, val_r2_score = valid_fn(valid_loader, model, criterion, device)
        
        elapsed = time.time() - start_time

        LOGGER.info(f'Epoch {epoch+1} - avg_train_loss: {avg_loss:.4f}  val_rmse_loss: {val_loss:.4f}  val_r2_score: {val_r2_score:.4f}  time: {elapsed:.0f}s')
    
        if val_loss < best_loss:
            es_count = 0
            best_loss = val_loss
            LOGGER.info(f'Epoch {epoch+1} - Save Lowest Loss: {best_loss:.4f} Model')
            torch.save(model.state_dict(), OUTPUT_DIR+f"{CFG.pretrained_model_name_or_path.split('/')[-1]}_best.pth")
            
        else:
            es_count += 1
            if es_count >= CFG.patience:
                print('early_stopping')
                break
    
    torch.cuda.empty_cache()
    gc.collect()

            
if __name__ == '__main__':
    train_loop(train_ds, valid_ds)
        
 

Summary

In this article, we fine-tuned pretrained T5 and DeBERTa models to predict the products and yields of chemical reactions. Beyond these, the same approach should be applicable to a variety of other SMILES-based tasks, such as retro-reaction prediction, stability prediction, and activity prediction, so please give it a try.
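For example, here is a minimal, hypothetical sketch of adapting the product-prediction training script above to retro-reaction prediction: the direction is simply reversed by swapping the input and target columns (`df`, `tokenizer`, and `CFG` are the objects from that script).

# Hypothetical retro-reaction setup: predict REACTANT from PRODUCT
df['input'] = 'PRODUCT:' + df['PRODUCT']

def preprocess_function(examples):
    # Same tokenization as before, but the target is now the reactant SMILES
    model_inputs = tokenizer(examples['input'], max_length=CFG.input_max_len, truncation=True)
    labels = tokenizer(examples['REACTANT'], max_length=CFG.target_max_len, truncation=True)
    model_inputs['labels'] = labels['input_ids']
    return model_inputs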

(I worked on this article as part of a laboratory internship: https://kojima-r.github.io/kojima/)
