2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

分子生成モデルの実装入門:SELFIES×VAEで新規分子を作る

Last updated at Posted at 2025-09-26

生成モデル(VAE)を使って新規分子を作るまで

「ニューラルネットワークで分子を生成できるの?🐱」
→はい、できます🐵

分子は、例えば「c1ccccc1」のようにSMILES という文字列形式で表すことができます。
このSMILESをトークンに分割し、時系列データとしてニューラルネットワークに入力すると、自然言語処理のように学習が可能になります。

今回紹介するのは VAE(Variational AutoEncoder) を使った分子生成です。
うまく学習をさせることができれば、元の分子に似ているが異なる分子を生成することができます

イメージはこんな感じです👇(MI-6さんの記事から引用)
image.png
図. VAEの構造および、潜在空間を用いた分子最適化のイメージ (出典: https://milab.mi-6.co.jp/article/t0028)

VAEでの流れ:alien:

  1. エンコーダ:入力SMILESをトークン化し、潜在空間に圧縮
  2. 潜在空間:ベクトルにノイズを加えることで「揺らぎ」を作る
  3. デコーダ:潜在ベクトルから再びSMILESを生成

この仕組みにより、既存分子と似ているが新しい分子を生成することができます。
それではやっていきましょう。

イントロでは SMILES と記載していますが、実際に 有効な分子を安定して生成できたのは SELFIES でした。

SELFIES(Self-Referencing Embedded Strings)は SMILES を改良した分子表記法で、生成された文字列が化学的に妥当な分子に対応するという特徴があります。
これにより、学習済みモデルから「壊れた構造」が出てくることを防ぐことができます。

GPU使用推奨です:bow_tone2: CPUのみでももちろん計算はできますが、時間かなりかかる可能性大です。持っていない方はGoogle Colabなどで無料GPU使いましょう。

一応、GPU付きPCのセットアップ記事も以前書いたので、興味ある方はこちらの記事見てみてください。
https://qiita.com/Osarunokagoya/items/761b0d731371ed50a30e

1. データセット読み込み&前処理

有機分子のPL波長が載ったこちらのデータセットを使います。特許と論文から波長を集めてきました。保存して、notebookと同じフォルダに入れてください。
https://drive.google.com/file/d/1JIJH0XqFydZKm8t4x1DoHDcFqIKFUir-/view?usp=sharing

必要なライブラリを読み込みます。

import pandas as pd
from tqdm import tqdm
import selfies as sf
import math
import matplotlib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

from rdkit import rdBase
from rdkit import Chem
from rdkit.Chem import Draw
from IPython.display import display

それぞれのversionは以下の通りです。

pandas: 2.3.2
tqdm: 4.67.1
selfies: 2.1.1
torch: 2.8.0+cu128
rdkit: 2025.03.5
matplotlib: 3.9.4

データセットを読み込みます。有機分子の構造(SMILES)とPL波長(有機分子は光るのですが、その発光波長)がセットで載っています。学習データ(正解あり)とテストデータ(正解なし)が混在しているので、今回は学習データのみを抽出します。

# データセット読み込み
df = pd.read_csv('material_data.csv', index_col=0)

# PLに欠損値が入っている行を消して、学習データのみにする
dataset_train = df.dropna(subset='PL')

# データセットの確認
print(df.head())
                                                     SMILES Type     PL
Material                                                               
BTD1      CC1(C)C2=CC(C3=C(N=S=N4)C4=C(C5=CC=C(C6=CC=CC=...  BTD  533.0
BTD2      CC1(C)C2=CC(C3=C(N=S=N4)C4=C(C5=CC=C(C6=CC=C(C...  BTD  540.0
BTD3      CC1(C)C2=C(C=CC(C3=CC(N4C(C=CC=C5)=C5C6=C4C=CC...  BTD  538.0
BTD4      CC1(C)C2=CC(C3=C(N=S=N4)C4=C(C5=CC=C(C6=CC=C(N...  BTD  584.0
BTD5      CC1(C)C2=CC(C3=C(N=S=N4)C4=C(C5=CC=C(C6=CC=C(C...  BTD  540.0

そして、データセットの中からSMILESの列を抽出し、有効なSMILES列のみ取り出します。

# SMILES列を抽出、欠損値を除いてリスト化
smiles = dataset_train['SMILES'].dropna().tolist()

# 無効なSMILESを削除
valid_smiles = []
for s in smiles:
    try:
        if Chem.MolFromSmiles(s) is not None:
            valid_smiles.append(s)
    except:
        continue

print('有効なSMILESの数', len(valid_smiles))
有効なSMILESの数 251

Molファイルという化学的な情報の入ったファイルがあるのですが、読み込んだSMILESにてこちらが計算できるかどうかで、有効なSMILESかどうかを判断しています。

有効なSMILESを取得出来たら、次にSELFIESを計算します。改めてSELFIESの復習ですが、SELFIESは、分子を「トークン」に分解してエンコードする方式で、どんな文字列を並べても必ず有効な分子になることが保証されているそうです。

下記がSMILESとSELFIESの違いの例です。

SMILES:  CCO
SELFIES: [C][C][O]

空のリストを用意して、計算したSELFIESを入れていきましょう。

# 既存SMILESデータ -> SELFIESへ一括変換
selfies_list = []
for s in valid_smiles:
    try:
        selfies_list.append(sf.encoder(s))  # 失敗したら except でスキップ
    except Exception:
        pass

print('有効なSELFIESの数', len(selfies_list))
有効なSELFIESの数 251

251全てのSMILESが問題なくSELFIESとして計算できたことが分かります。

次にこのSELFIESをニューラルネットワークに食わせるため「辞書作り」を行います。自然言語処理で文章を単語やサブワードに分解しIDにするのと同じ流れです。

チャッピーに作ってもらったイメージはこんな感じです。
image.png

# アルファベット作成
alphabet = sf.get_alphabet_from_selfies(selfies_list)  # SELFIESで出てくる全種類のトークンを集める
alphabet = sorted(list(alphabet))
special = ["<pad>", "<start>", "<end>"]
tokens  = alphabet + special  # 特殊トークンを追加
print('トークンの数', len(tokens))

# 文字→ID、ID→文字の辞書を作成
stoi = {t:i for i,t in enumerate(tokens)}
itos = {i:t for t,i in stoi.items()}
pad_idx   = stoi["<pad>"]; start_idx = stoi["<start>"]; end_idx = stoi["<end>"]  # さらに特殊トークンのIDも控えておく
トークンの数 38

まずalphabetという変数を用意し、データセット中のSELFIESを全部見て、出現した トークン([C], [O], [=C], [Ring1] ... など) を集めます。

次に特殊トークンを追加します。これらは分子の形を整えるのに必要です。
pad … シーケンス長を揃えるための空白埋め
start … デコーダで「ここから生成開始するよ」という印
end … 「ここで分子が終わった」という終了の印

そして、辞書の作成を行います。
stoi : 文字列([C], [O], "" …)を整数IDに変換する辞書
itos : 整数IDを文字列に戻す辞書

最後に特殊トークンのIDを控えておきます(pad_idx etc.)、これらは学習や生成でよく使うため、変数にしておきます。

これらをPytorchで読み込むために追加の処理を行います。よく見るデータセット、データローダーの定義です。

# SELFIESを分割する関数
def selfies_tokens(x): # SELFIESをトークンのリストに分割、=> ["[C]", "[C]", "[O]"]
    return list(sf.split_selfies(x))

# エンコード関数(数値化+長さ揃え)
def encode_selfies(x, max_len=120):
    toks = ["<start>"] + selfies_tokens(x) + ["<end>"]  # 例: "C(C)O" → SELFIES → "[C][C][O]" → ["[C]", "[C]", "[O]"]、開始と終了を追加
    ids = [stoi[t] for t in toks if t in stoi]  # トークンを数字に変換、stoi は「string → index(数字)」辞書。例えば "[C]" → 7, "<start>" → 45 みたいに割り当てられる。
    ids = ids[:max_len] + [pad_idx] * max(0, max_len - len(ids))  # 長さを max_len に揃える
    return ids

# データセット&DataLoaderの作成
class SelfiesDataset(Dataset):
    # 初期化時にSELFIESリストを数値化して保存
    def __init__(self, selfies_list, max_len=120):
        self.data = [encode_selfies(s, max_len) for s in selfies_list]

    # データ数とデータ取得の定義
    def __len__(self):
        return len(self.data)

    # i番目のデータをテンソルで返す
    def __getitem__(self, i):
        return torch.tensor(self.data[i], dtype=torch.long)

# データセット、データローダーの定義
dataset = SelfiesDataset(selfies_list, max_len=120)
loader  = DataLoader(dataset, batch_size=64, shuffle=True)

最初のselfies_tokens関数は、SELFIES文字列をトークン単位に分割します。出力がリストとなり、ニューラルネットワークが扱いやすくなります。次のencode_selfies関数の中で使われています。

encode_selfies関数は、数値化+パディングを行います。開始トークンstartと終了トークンendを追加し、デコーダに「ここから生成」「ここで終了」を明示しています。そして先ほど作ったstoi辞書を使って、各トークンを数値IDに変換します。さらに、分子ごとに長さが違うので、最大長max_lenに揃えるためにpadを追加します。この関数はdataset定義の中で使われています。

まだまだ定義が続きますが(ひー🐯🐱)、とりあえずコピペ実行で動くはずなので、中身が理解できなくてもまず動かしてみて何が起こるか見てみてください:runner_tone2:

2. VAEモデル、サンプリング方法、損失関数の定義

■ VAE

エンコーダ+潜在空間+デコーダという構成になります。

エンコーダに入力された分子群により潜在空間が構築され、この潜在空間からサンプリングすることにより、デコーダが新規分子を生成するという流れです。

↓よく見るVAEの構成図、図では画像が入力だが、今回は文字(SELFIESです)、こーゆのを作っていきます。
image.png
図. 変分オートエンコーダ VAE Variational AutoEncoder の構造(https://cvml-expertguide.net/terms/dl/deep-generative-model/vae/)

class SelfiesVAE(nn.Module):
    """
    SELFIESトークン列を対象にしたVAE
    """
    def __init__(self, vocab_size, hidden=256, latent=64, max_len=120):
        super().__init__()
        self.max_len = max_len
        # ★ padding_idx を設定
        self.embed = nn.Embedding(vocab_size, hidden, padding_idx=pad_idx)

        # Encoder
        self.encoder_rnn = nn.GRU(hidden, hidden, batch_first=True)
        self.mu = nn.Linear(hidden, latent)
        self.logvar = nn.Linear(hidden, latent)

        # Decoder
        self.decoder_rnn = nn.GRU(hidden, hidden, batch_first=True)
        self.fc_out = nn.Linear(hidden, vocab_size)
        self.latent_to_hidden = nn.Linear(latent, hidden)

    def encode(self, x):
        x_emb = self.embed(x)              # [B,T,H]
        _, h = self.encoder_rnn(x_emb)     # h: [1,B,H]
        h = h[-1]                          # [B,H]
        return self.mu(h), self.logvar(h)

    def reparam(self, mu, logvar):
        eps = torch.randn_like(mu)
        return mu + eps * (0.5 * logvar).exp()

    def forward(self, x, word_dropout_p: float = 0.15):
        """
        word dropout(teacher forcing入力の一部を<pad>に)
        目的:デコーダがzを見ないで「入力コピー」し続けるのを防ぐ
        やること:forward()内で、学習中だけdec_inp_idsの一部を<pad>に変換

        学習用: teacher forcing でロジットを返す
        x: [B,T](<start> ... <end> + <pad>)
        予測対象は x[:,1:]、デコーダ入力は x[:, :-1]
        """
        mu, logvar = self.encode(x)
        z = self.reparam(mu, logvar)
        h0 = self.latent_to_hidden(z).unsqueeze(0)      # [1,B,H]

        # 教師強制
        dec_inp_ids = x[:, :-1]                         # [B, T-1]、デコーダ入力、正解トークンを次の入力として与える
        dec_tgt_ids = x[:, 1:]                          # [B, T-1]、デコーダの教師信号、出力すべき正解ラベル

        # ★ word dropout(学習時のみ)
        if self.training and word_dropout_p > 0:
            # pad / end はなるべく触らない(startは先頭1ステップ目だけなので気にしすぎなくてOK)
            keep_mask = (dec_inp_ids != pad_idx) & (dec_inp_ids != end_idx)
            drop_rand = torch.rand_like(dec_inp_ids.float())
            drop_mask = (drop_rand < word_dropout_p) & keep_mask
            dec_inp_ids = dec_inp_ids.masked_fill(drop_mask, pad_idx)

        dec_inp = self.embed(dec_inp_ids)
        dec_out, _ = self.decoder_rnn(dec_inp, h0)
        logits = self.fc_out(dec_out)
        return logits, dec_tgt_ids, mu, logvar

    @torch.no_grad()
    def sample(self, z, max_len=None, p=0.9, temperature=1.05, use_topp=True):
        if max_len is None:
            max_len = self.max_len
        h = self.latent_to_hidden(z).unsqueeze(0)
        B = z.size(0)
        cur = torch.full((B, 1), start_idx, dtype=torch.long, device=z.device)
        cur_emb = self.embed(cur)
        out_ids = []
        
        for _ in range(max_len-1):
            dec_out, h = self.decoder_rnn(cur_emb, h)
            logits = self.fc_out(dec_out)  # [B,1,V]
            if use_topp:
                next_id = sample_top_p(logits, p=p, temperature=temperature)  # ←確率サンプル
            else:
                next_id = logits.argmax(-1)                                   # ←貪欲
            out_ids.append(next_id)
            if (next_id == end_idx).all(): break
            cur_emb = self.embed(next_id)
        return torch.cat(out_ids, dim=1) if out_ids else cur
    
def sample_top_p(logits, p=0.9, temperature=1.0):
    """
    これは 「潜在ベクトルから分子を作るときのデコーダ側で、次の文字(トークン)をどう選ぶか」 を制御する関数です。
    デコーダが吐いた次のトークンの確率分布(logits)から、どのトークンを選ぶか決める部分
    Greedy(確率最大のトークンを選ぶ, argmax)なら毎回一番確率が高いトークンをとる→単調になりがち
    Top-p(nucleus)サンプリングは、「累積確率がpを超えるまでの上位候補だけを残し、その中から確率的に選ぶ」という方法
    """

    # 温度スケーリング、temp > 1で分布が平ら→多様性増える、temp < 1で分布が尖る→多様性減る
    if temperature != 1.0:
        logits = logits / temperature

    # 確率化&ソート
    probs = F.softmax(logits, dim=-1)            # [B,1,V]、確率に変換、logtisはNNの出力層から出てくる確率に変換する前の生のスコア
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)  # 確率の大きい順に並べて、累積和をとる

    # Top-pフィルタリング、上位から累積確率がp(例:0.9)に達するまで残す、それ以外は確率0にして正規化
    mask = cumsum <= p
    mask[..., 0] = True                          # 少なくとも最大確率トークンは残す
    filtered = torch.where(mask, sorted_probs, torch.zeros_like(sorted_probs))
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)

    # サンプリング、残した候補の中から確率に従ってランダムに選ぶ、next_id が次のトークンID
    next_rel = torch.multinomial(filtered.squeeze(1), 1)     # [B,1]
    next_id  = sorted_idx.gather(-1, next_rel.unsqueeze(-1)).squeeze(-1)  # [B,1]→[B,1]
    return next_id

VAEモデルの仕組みとコードの内容を説明します

まずエンコーダ側では、

→埋め込み層でトークンIDを連続ベクトルに変換
→GRUに入力し、最後の隠れ状態 h を取得
→そこから潜在分布の 平均 μ と 分散 logσ² を計算

次に、Reparameterization Trickを使って、
z = μ + σ ⊙ ε (εは標準正規乱数)として潜在ベクトルzをサンプリングします。このトリックにより誤差逆伝播が可能になります。

デコーダへの入力は2種類あります。

  1. 潜在ベクトル z
    エンコーダから出てきた潜在変数zをlatent_to_hidden(z)に変換して、初期隠れ状態に与えます。これにより「どんな分子を作るか」という全体像を制御します

  2. 教師強制
    「次トークンの正解」を入力として強制的に入れてしまいます。これにより学習が安定します

学習時にデコーダは教師強制により「前の正解トークン」を入力として次のトークンを予測します。しかし、これを100%やってしまうと潜在ベクトルzを無視して、入力をそのままコピーするだけになってしまいがちです。それだと、ただのオートエンコーダとなってしまい、posterior collapseと呼ばれる現象が起こってしまいます。

これだと新規分子をつくることができず、既存分子が再度でてくるだけです。。。:scream:

そこで、Word Dropoutと呼ばれる工夫を施しています。入力トークンの一部を強制的にpadに置き換えてしまいます。そうすることで、デコーダは「正解入力」が一部欠損していることになり、その穴を埋めるために潜在ベクトルzの情報を参照せざるを得ない状態になります。これにより、zに情報をしっかり乗せるように学習が進むわけです。

■ 潜在空間からのサンプリング方法

学習が完了すると、実際に分子を生成するときにデコーダは改めて「次のトークン」を1つずつ予測していくわけですが、このときのサンプリング方法(選び方)にはいくつか方法があります。

  • Greedy(argmax)
    毎回「最も確率が高いトークン」を選ぶ方法
    →安定して同じ分子を出すが、多様性が出にくい

  • Top-p sampling(nucleus sampling)
    「確率の大きい順に並べて、累積確率がp(例:0.9)を超えるまで残す」
    →その候補からランダムに選ぶ、ありえない候補を捨てつつ、ランダム性があるので多様性が生まれる

  • temperature
    softmax前の確率分布を「温度」で調整
    T > 1 : 分布が平らになり、多様性アップ
    T < 1 : 分布が尖り、確率最大のトークンに寄る(保守的)

コードでいうと、@torch.no_grad()以下がサンプリングに該当します。

今回はTop-p x temperatureと組み合わせて使うことで、「それなりに正しいが多様性のある分子」を作りに行ってます。sample_top_p関数がそれに該当します。関数の中でtemperatureを定義しています、この値をいろいろいじってみて、生成する分子に変化があるか見てみてください。

ランダムサンプリングをしたい場合は、use_topp=Falseにすれば実行できます。

■ 損失関数

最後に損失関数を定義します。2つ定義しており、1つ目は標準的なもの。ただ、これだとうまくいかなかったので、修正したものが2つ目になります。

(1) 標準的なβ-VAE損失関数
損失は 再構成誤差KLダイバージェンス の和で表されます。

image.png

ここで、

  • $\mathrm{CE}(x, \hat{x})$ : Cross Entropy 損失
  • $q_\phi(z|x) = \mathcal{N}(\mu, \mathrm{diag}(\sigma^2))$
  • $p(z) = \mathcal{N}(0, I)$

KLの具体式は、

$$
D_{\mathrm{KL}}\big[q(z|x),|,p(z)\big]
= -\tfrac{1}{2} \sum_{i=1}^d
\Big( 1 + \log\sigma_i^2 - \mu_i^2 - \sigma_i^2 \Big)
$$

コードで表現すると、

def vae_loss(logits, tgt_ids, mu, logvar, beta=1.0, weight=None, label_smoothing=0.1):
    """
    通常の損失関数
    """
    B, Tm1, V = logits.size()
    recon = F.cross_entropy(
        logits.reshape(B*Tm1, V),
        tgt_ids.reshape(B*Tm1),
        ignore_index=pad_idx,
        weight=weight,
        label_smoothing=label_smoothing
    )
    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + beta*kld, recon, kld
  • 再構成誤差 recon
    logits: 各時刻の語彙分布の“生スコア”
    tgt_ids: 教師強制の正解ID
    F.cross_entropy でシーケンス全体(B×(T-1) トークン)に対する平均CEを計算
    ignore_index=pad_idxpad は損失に入れない
    weight=weight はクラス重み(例:endに重みを足して文をきちんと終わらせやすくするなど)
  • KLダイバージェンス kld
    潜在分布 q(z|x)=N(μ, σ²) と 事前分布 p(z)=N(0, I) のKL
    -0.5 * mean(1 + logσ² − μ² − exp(logσ²))として、バッチ平均にする
  • 全体損失 loss = recon + β·kld
    β を大きくすると 潜在の規則化が強まり 構造がきれいに、ただし再構成は難しく
    β を小さくすると 再構成重視、ただし z が役立たずになりやすい(posterior collapse)

label_smoothingはクロスエントロピー損失において、正解ラベルに100%の確率を置かずに、少し別のクラスにも分布を持たせるテクニックです。通常だと、正解クラスが1.0となるところを、label_smoothing=0.1とすると、正解クラスが0.9となります。過学習になるのを防ぎ、潜在空間に多様性が生まれます。しかし大きすぎると、正解と不正解の差が曖昧になり、学習が進みにくくなるので注意。

(2) Free-bits付き損失関数

image.png

ここで、

  • $D_{\mathrm{KL}}^{(i)}$ は潜在次元 $i$ に対応するKL項
  • $\tau$ は「free-bits」パラメータ(例: 0.5〜1.0 nats)

最大の目的はposterior collapse 対策です。デコーダが強力だと、z を無視してKL→0になりがちですが、free-bits は 各潜在次元のKL に下限 τ を与える発想となります。

コードは以下になります。torch.clamp(min=tau) で 各次元のKLを最低でも τ とみなして損失に入れており、これにより 「KL=0 で済ませる」解が不利 になり、z に情報をのせる方向へ学習が進みます。

def vae_loss_freebits(logits, tgt_ids, mu, logvar, beta, tau=0.5, weight=None, label_smoothing=0.0):
    """
    学習が進むにつれて、kld=0 = 潜在空間が無意味になっている(posterior collapse)が起こってしまう。
    これは、デコーダが強力すぎて、潜在 z に頼らなくても「teacher forcing」で入力系列をコピーできる。
    その結果:エンコーダは μ=0, logvar=0 に収束 → サンプリングしても z はほぼ 0 ベクトル。つまり 潜在空間が死んでる状態。

    KL の各次元が最低 ε (例: 0.5) になるようにする。→ z に情報を載せざるを得なくなる。
    tau: 1次元あたりのKL下限(nats)。例: 0.5〜1.0 を試す
    """
    B, Tm1, V = logits.size()
    # 再構成
    rec = F.cross_entropy(
        logits.reshape(B*Tm1, V),
        tgt_ids.reshape(B*Tm1),
        ignore_index=pad_idx,
        weight=weight,
        label_smoothing=label_smoothing
    )
    # 次元ごとのKL([B, Z])
    # KL_i = -0.5 * (1 + logvar_i - mu_i^2 - exp(logvar_i))
    kld_per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())

    # free bits: 各次元のKLに下限tauを課す(平均の前に適用)
    kld_per_dim_clipped = torch.clamp(kld_per_dim, min=tau)  # 下限をtauとして、0に落ち込むのを防ぐ

    # バッチ平均 & 次元平均
    kld_fb = kld_per_dim_clipped.mean()

    loss = rec + beta * kld_fb
    return loss, rec, kld_fb

ここまでお疲れさまでした:raised_hand:
定義が長くなりましたが、いよいよ学習開始です:nerd:

損失関数の中身を見ると、あれっ?:open_mouth:となることがあるかと思います。そもそも、VAEの目的は「似ているけど違う分子を作りたい」ということでした。ただ損失関数の構成として、再構成誤差はエンコーダへの入力とデコーダの出力が一致するほど小さくなるものです。KLダイバージェンスは、エンコーダが推定した分布と標準正規分布が一致すると小さくなります。

つまり損失関数が小さくなりすぎるということは、

  • エンコーダ入力とデコーダ出力が一致 ⇒ 入力をコピーしてるだけ
  • KLダイバージェンスが極端に小さい ⇒ 潜在変数が無視されてしまい、サンプリングしても同じ分子しか出ない(posterior collapse)

となるわけで、VAEの目的を満たしません。だから小さければよいわけではなく、バランスが大事ということになります。このバランスをとるために、

  • β-VAE(KLの重みを調整)
  • KLアニーリング(最初は無視して後から効かせる、後で出てきます)
  • Free-bits(KLがゼロにならないように下限を設ける)

などの工夫を施しているわけです:ok_hand:

3. 学習と分子生成

■ 学習

下記コードでモデルに学習させていきましょう。学習条件を設定し、今まで定義したものを呼び起こしていきます。

# 学習条件の設定
USE_ANNEAL  = True      # True: KLアニーリング有効、False: βは常に1.0
USE_FREEBITS = True     # Freebitsを使うかどうか、
tau = 1.0               # まずは0.5から、弱ければ0.25、強ければ0.75-1.0を試す
epochs = 500              
kl_warmup_epochs = 30   # KLアニーリング期間
clip_grad = 1.0         # 勾配クリップ、勾配爆発対策、勾配ノルムがclip_grad以上になったらスケーリングする

# モデル構築
device = "cuda" if torch.cuda.is_available() else "cpu"
model = SelfiesVAE(vocab_size=len(tokens), hidden=256, latent=64).to(device)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
weight = torch.ones(len(tokens), device=device)
weight[end_idx] = 1.3  # EOSを少し出しやすく

# 学習の履歴
history = {"epoch": [], "loss": [], "rec": [], "kld_eff": [], "kld_raw": []}

for epoch in range(1, epochs+1):
    model.train()

    # アニールを使うかどうか、使う場合は徐々に0→1
    beta = min(1.0, epoch / kl_warmup_epochs) if USE_ANNEAL else 1.0

    total_loss = total_rec = total_kld_eff = total_kld_raw = 0.0

    for batch in tqdm(loader, leave=False):
        batch = batch.to(device)        # [B, T]
        logits, tgt_ids, mu, logvar = model(batch)
        
        # Freebitsを使うかどうか
        if USE_FREEBITS:
            loss, rec, kld_eff = vae_loss_freebits(logits, tgt_ids, mu, logvar, beta=beta, tau=tau)
            # 監視用に raw KL も計算(勾配不要)
            with torch.no_grad():
                kld_raw = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).mean()
        else:
            loss, rec, kld_eff = vae_loss(logits, tgt_ids, mu, logvar, beta=beta)
            kld_raw = kld_eff  # そのまま

        opt.zero_grad()
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
        opt.step()

        total_loss    += loss.item()
        total_rec     += rec.item()
        total_kld_eff += kld_eff.item()
        total_kld_raw += kld_raw.item()

    # 損失関数などの平均
    n = len(loader)
    avg_loss = total_loss / n
    avg_rec  = total_rec  / n
    avg_kld_eff = total_kld_eff / n
    avg_kld_raw = total_kld_raw / n  # USE_FREEBITS=False のときは kldと同義

    # 履歴に追加
    history["epoch"].append(epoch)
    history["loss"].append(avg_loss)
    history["rec"].append(avg_rec)
    history["kld_eff"].append(avg_kld_eff)
    history["kld_raw"].append(avg_kld_raw)

    
    # 10エポックごとに表示
    if epoch % 10 == 0:
        if USE_FREEBITS:
            print(f"Epoch {epoch:03d} | beta={beta:.2f} | tau={tau:.2f} | "
                f"loss={avg_loss:.3f} | rec={avg_rec:.3f} | "
                f"kld_eff={avg_kld_eff:.3f} | kld_raw={avg_kld_raw:.3f}")
        else:
            print(f"Epoch {epoch:03d} | beta={beta:.2f} | "
                f"loss={avg_loss:.3f} | rec={avg_rec:.3f} | kld={avg_kld_eff:.3f}")

まず、学習の条件を見てみましょう。

  • USE_ANNEAL : KLアニーリングを使うかどうかです。β を徐々に 0→1 に上げるという手法です。学習の初期は再構成能力が低いため、KL項の影響が強いと、潜在分布をとにかく標準正規分布に合わせようとして、posterior collapseが起きやすくなってしまいます。なので、まずは再構成誤差を減らして学習を安定化するのに注力するため、ベータを段階的に上げるのがKLアニーリングです。

  • USE_FREEBITS : Free-bits付の損失関数を使うかどうかを選べます。ちなみに使わないと変な分子しか生成されません。τ の目安として、0.25〜1.0 nats/次元で調整してみてください。大きいほどzに必ず情報を載せるので新規性は上がりますが、学習分子から外れやすくはなります。

  • weight : 損失関数における再構成誤差の中で使用されます。意図的にEOS(endトークン)に荷重をかけることで、生成時に「文を途中で終える」確率を上げています。これにより無限にだらだら構造を拡張することを防ぎます。

他にも勾配クリップやアニーリング期間などのパラメータがありますが、正直この辺は結果を見ながら修正していくことになります。

これらの条件を設定し、モデル、最適化関数、損失関数を定義してしまえば、後は他のニューラルネットワークと同じです。

学習結果を可視化しましょう:writing_hand:

# 可視化
epochs = history["epoch"]

plt.figure(figsize=(7,4))
plt.plot(epochs, history["loss"], label="loss")
plt.plot(epochs, history["rec"],  label="rec")
plt.plot(epochs, history["kld_eff"], label="kld_eff")
# (freebits時の参考)素のKLも見たければ:
plt.plot(epochs, history["kld_raw"], label="kld_raw", linestyle="--")

plt.xlabel("epoch")
plt.ylabel("value")
plt.title("Training curves")
plt.legend()
plt.tight_layout()
plt.show()

# 参考:再構成Perplexity(= exp(rec))も見たい場合
ppl = [math.exp(r) for r in history["rec"]]
plt.figure(figsize=(7,4))
plt.plot(epochs, ppl, label="perplexity")
plt.xlabel("epoch"); plt.ylabel("PPL")
plt.title("Reconstruction Perplexity")
plt.legend(); plt.tight_layout(); plt.show()

image.png

lossは損失関数で再構成誤差とKLダイバージェンスの和でした。rec(再構成誤差)が学習が進むにつれて減っています。また、tauを1に設定したのでkld_effは1のままで、これ以上は下がらないようになっています。これによりkld_eff=0となるのを防げています。よって、

lossが減っているため学習は安定し、kldも0となっていないので潜在空間が意味を持ち、reconstructionとKLのバランスがとれた!と判断していいでしょう:v:

■ 分子生成

それでは、やっと本題の新規分子生成です!!また相変わらずコード長いですが:skull:、コピペでOKです。生成する分子の数を変更したければ、N_GENの数字を生成したい数に変更してください(今は100分子にしてます)。

# ==== 設定 ====
USE_NEIGHBOR = True     # True: 近傍サンプリング, False: ランダムサンプリング
N_GEN = 100             # 生成分子数
noise_scale = 0.05      # 近傍に足すノイズの強さ、0.05など小さくすると学習データにほぼ近い、0.2などにすると多様性が増える
seed = 1234             # 乱数seed
g_gen = torch.Generator(device=device).manual_seed(seed)

# ==== フィルタ設定 ====
USE_FILTER = True   # True: 物理化学的なフィルタを適用, False: フィルタせず全通過

def ok_material_molecule(m):
    try:
        Chem.SanitizeMol(m)
    except:
        return False
    atom_symbols = [a.GetSymbol() for a in m.GetAtoms()]
    allowed = {"C","H","N","O","S","F","Cl","Br","I","Si","B"}  # 材料向けに調整
    if not all(sym in allowed for sym in atom_symbols):
        return False
    n_atoms = m.GetNumAtoms()
    if not (20 <= n_atoms <= 150):
        return False
    return True

# 潜在ベクトルの準備
if USE_NEIGHBOR:
    # 複数バッチからmuを集める(学習データから代表点を確保)
    mu_bank = []
    for i, batch in enumerate(loader):
        batch = batch.to(device)
        mu, _ = model.encode(batch)
        mu_bank.append(mu)
        if len(mu_bank) * batch.size(0) > 2000:  # 2000点くらいで十分
            break
    mu_bank = torch.cat(mu_bank, dim=0)

    # そこからランダムに N_GEN 個選んでノイズを足す
    idx = torch.randint(0, mu_bank.size(0), (N_GEN,), device=device, generator=g_gen)
    z_base = mu_bank[idx]
    noise = torch.randn(z_base.shape, device=device, generator=g_gen)
    z = z_base + noise_scale * noise
    print(f"[近傍サンプリング] mu_bankから {N_GEN} 個のzを取得 (noise_scale={noise_scale})")
else:
    z = torch.randn(N_GEN, 64, device=device, generator=g_gen)
    print(f"[ランダムサンプリング] {N_GEN} 個のzを生成")

# ==== 分子生成 ====
model.eval()
with torch.no_grad():
    """
    温度スケーリング:デコーダが出すlogitsを割るパラメータ
    temperature < 1 →分布が尖る(確率の高いトークンが強調される←出力が決まりやすい)
    temperature > 1 →分布が平らになる(低確率のトークンも選ばれやすい←多様性が増える)

    p(top-p/nucleusサンプリング)
    「確率の大きいトークンから順に累積確率を足していき、pに達するまでの候補だけを残す」方式
    小さいp(0.7-0.8) : 候補が狭くなり保守的→生成が安定
    大きいp(0.9-0.95) : 候補が広がり多様性増加→変な分子も混じる
    """
    gen_ids_batch = model.sample(z, p=0.75, temperature=0.6, use_topp=True)  # 温度を低く、pを小さくすると保守的な分子に

def ids_to_selfies(ids):
    toks = []
    for t in ids:
        if t == end_idx: break
        if t not in (pad_idx, start_idx):
            toks.append(itos[t])
    return "".join(toks)


mols, smiles_list = [], []
for i in range(gen_ids_batch.size(0)):
    gen_sfs = ids_to_selfies(gen_ids_batch[i].tolist())
    gen_smi = sf.decoder(gen_sfs)
    mol = Chem.MolFromSmiles(gen_smi)

    if mol is None:
        continue
    
    # ==== フィルタ適用 or スキップ ====
    if USE_FILTER:
        if not ok_material_molecule(mol):
            continue

    mols.append(mol)
    smiles_list.append(gen_smi)

print(f"Valid ratio (after filter={USE_FILTER}): {len(mols)}/{gen_ids_batch.size(0)}")

# 可視化
img = Draw.MolsToGridImage(mols, molsPerRow=10, subImgSize=(200,200), legends=[f"{i}" for i in range(len(mols))])
display(img)

image.png

100分子生成されました🎉
ただこれだと、どの分子が新規なのかわからないので、「既存 or 新規」かを分けて描写します。

# 1) 学習データ SELFIES → 正規化SMILES セット化(初回だけ)
from rdkit import Chem
import selfies as sf

train_smiles_set = set()
train_idx_by_smi = {}  # canonical smi -> 最初に見つかった学習側index

for i, sfs in enumerate(selfies_list):
    try:
        smi = sf.decoder(sfs)
        m = Chem.MolFromSmiles(smi)
        if not m:
            continue
        can = Chem.MolToSmiles(m, canonical=True)  # 立体も区別(不要なら isomericSmiles=False)
        if can not in train_idx_by_smi:
            train_idx_by_smi[can] = i
        train_smiles_set.add(can)
    except Exception:
        pass

print("train uniques:", len(train_smiles_set))

# 2) 生成分子の正規化SMILESを作って一致判定
gen_can = []
dupes = []   # (gen_idx, canonical_smi, train_idx)
novel = []   # (gen_idx, canonical_smi)

for j, m in enumerate(mols):
    can = Chem.MolToSmiles(m, canonical=True)
    gen_can.append(can)
    if can in train_smiles_set:
        dupes.append((j, can, train_idx_by_smi[can]))
    else:
        novel.append((j, can))

print(f"生成: {len(mols)} / 完全一致: {len(dupes)} / 新規: {len(novel)}")
# 例:先頭数件だけ表示
for j, can, i_tr in dupes[:10]:
    print(f"[GEN {j}] == TRAIN {i_tr}  {can}")

# 3) もし重複/新規だけでグリッド可視化したい場合(任意)
from rdkit.Chem import Draw
from IPython.display import display

# 重複だけ
dup_mols = [mols[j] for j, _, _ in dupes]
if dup_mols:
    img_dup = Draw.MolsToGridImage(
        dup_mols, molsPerRow=8, subImgSize=(200,200),
        legends=[f"gen {j} = train {i_tr}" for j,_,i_tr in dupes]
    )
    display(img_dup)

# 新規だけ
nov_mols = [mols[j] for j,_ in novel]
if nov_mols:
    img_nov = Draw.MolsToGridImage(
        nov_mols, molsPerRow=8, subImgSize=(200,200),
        legends=[f"gen {j}" for j,_ in novel]
    )
    display(img_nov)
train uniques: 238
生成: 100 / 完全一致: 66 / 新規: 34
[GEN 2] == TRAIN 46  c1csc(-c2ccc(-c3ccc(-c4ccc(-c5cccs5)s4)c4nsnc34)s2)c1
[GEN 4] == TRAIN 5  c1ccc2c(c1)oc1cc(-c3ccc(-c4ccc5c(c4)oc4ccccc45)c4c3N=S=N4)ccc12
[GEN 6] == TRAIN 51  c1ccc(-c2ccc(-c3ccc(-c4ccc(-c5ccccc5)cc4)c4nsnc34)cc2)cc1
[GEN 7] == TRAIN 57  FC(F)(F)c1ccc(-c2ccc(-c3ccc(C(F)(F)F)cc3)c3nsnc23)cc1
[GEN 8] == TRAIN 207  CC(C)(C)c1ccc(N(c2ccccc2)c2ccc3ccc4c(N(c5ccccc5)c5ccc(C(C)(C)C)cc5)ccc5ccc2c3c54)cc1
[GEN 9] == TRAIN 25  COc1ccc(N(c2ccc(OC)cc2)c2ccc(-c3ccc(-c4ccc(-c5ccc(-c6ccc(N(c7ccc(OC)cc7)c7ccc(OC)cc7)cc6)s5)c5nsnc45)s3)cc2)cc1
[GEN 10] == TRAIN 183  c1ccc2c(c1)Oc1cccc3c1B2c1ccccc1O3
[GEN 11] == TRAIN 81  c1ccc(C(=C(c2ccccc2)c2ccc(-c3ccc(-c4ccc(-c5ccc(-c6ccc(C(=C(c7ccccc7)c7ccccc7)c7ccccc7)cc6)s5)c5nsnc45)s3)cc2)c2ccccc2)cc1
[GEN 13] == TRAIN 68  c1ccc(C(c2ccc(-c3ccc(-c4ccc(C(c5ccccc5)c5cccc6ccccc56)cc4)c4nsnc34)cc2)c2cccc3ccccc23)cc1
[GEN 16] == TRAIN 104  CC1=C(c2cc3cc(C)ccc3s2)C(C)=[N+]2C1=C(c1ccc(C)cc1)c1c(C)c(-c3cc4cc(C)ccc4s3)c(C)n1[B-]2(F)F

↓生成された分子のうち既存分子と一致したもの(既にデータセットに存在していたもの)
image.png

完全新規(VAEによって生み出されたオリジナル分子)
image.png

34個の分子が新規分子として生成されています。中には「明らかにこんなん作れんやろ笑」みたいな怪しいやつもいますが、現実的なやつもちらほら混ざっています。見てみますか!

gen 36 こんなん誰か作ってそうやな:grin:
image.png

gen 3 お!そこに三重結合いれちゃう?:hushed:
image.png

gen 85 失敗作、こんなの作れまちぇん:alien:
image.png

てな感じで遊んでみてください笑
好きな分子をindexで抽出して描写したい場合は、下記コード使ってください。

from rdkit.Chem import Draw
from IPython.display import display

def show_gen(idx, size=(350, 300)):
    """生成配列 mols の idx 番を1枚だけ表示"""
    if idx < 0 or idx >= len(mols):
        raise IndexError(f"idx out of range (got {idx}, len={len(mols)})")
    display(Draw.MolToImage(mols[idx], size=size))

# 表示したい分子を数字で指定
show_gen(85)

まとめ&所感

VAEを使って新規分子を生成させてみました。コードはチャッピーに描いてもらいましたが、意外と動いたなというのが正直な感想です。

特に Free-bits loss と組み合わせると posterior collapse 回避 に有効でした 🙆‍♂️

生成する分子をもう少しフィルタリングしたり、損失関数に類似度を判定するものをいれたりすると、もう少しマシになるかもしれません。また改善あったら記事書こうと思います🐵

長文読んでいただきありがとうございました!!

引用

  • AI分子生成の導入と基本手法の紹介 MI-6
    https://milab.mi-6.co.jp/article/t0028
    もともとのきっかけはこの記事です。図も引用させていただきました。分かりやすい。
  • VAE (Variational Autoencoder, 変分オートエンコーダ)
    https://cvml-expertguide.net/terms/dl/deep-generative-model/vae/
    VAEに関してはこちらのサイトを参考にさせていただきました。詳細も丁寧にまとめられており、VAEに限らずニューラルネットワーク全般の内容がまとめられています。E資格対策の時もお世話になりました。
  • Automatic Chemical Design Using a Data-Driven Continuous Representation of Molecules
    https://pubs.acs.org/doi/10.1021/acscentsci.7b00572
    VAEで分子生成の基礎となったであろう論文。著者らはSMILESで生成を行っています。しかもVAEに予測器も搭載させており、生成した新規分子の物性予測も行っています。今度はこれに挑戦したい。
2
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
2
0

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?