0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

CNNを使って構造式(画像)からlogpを予測する

Posted at

概要

オクタノール・水分配係数(logp)は疎水性の指標となる値です。疎水性は分子構造により変化するので、構造式をニューラルネットワークのインプットとすれば予測することが可能だと考えました。
コードはgithubに公開しています。 → https://github.com/hiroyakubo/image-logp

開発環境

  • Windows11 WSL2 Ubuntu20.04
  • python 3.6.15

環境構築

pythonライブラリをインストールするだけです。

pip install -r requirements.txt

コード解説

データセット

まず分子のSMILES一覧をCSVファイルとして用意しました。このSMILES記法で表現された分子をrdkitを使って画像に変換していきます。

    def create_image(self, size:tuple=(224, 224)) -> None:
        """Create image from smiles
        Parameters
        ------
        size : tuple
            image size
        """
        data = {}
        pbar = tqdm(enumerate(self.df["SMILES"].to_numpy()), total=len(self.df))
        for i, smiles in pbar:
            mol = Chem.MolFromSmiles(smiles)
            if mol is not None:
                idx = str(i).zfill(5)
                Draw.MolToFile(mol, os.path.join(self.image_dir, f"{idx}.png"), size=size)

                logp = Crippen.MolLogP(mol)
                data[idx] = [logp, smiles]

        with open(os.path.join(self.data_dir, f"{self.file_stem}_label.json"), "w") as f:
            json.dump(data, f, indent=4)

SMILES記法で表現された文字列をMolFromSmilesでMolオブジェクトに変換します。その後、Draw.MolToFileでMolオブジェクトを画像ファイルへと変換しています。インクリメントされるインデックス番号を付与して画像ファイル名としています。
この時、同時に各分子のlogpを算出しています。このままでは画像とlogp値、SMILESの対応がわからなくなってしまうので、それぞれの対応を記載したjsonファイルを保存しておきます。
00000.png00002.png00001.png
上の画像は作成した構造式の例です。rdkitを使って画像を作成すると官能基に色を付けてくれます。もしかすると、このように官能基に色がついていた方がニューラルネットワークとしても分子構造を理解しやすいかもしれません。

    def create_dataset(self, prepare:bool=False) -> None:
        """Create logp dataset
        Parameters
        ------
        prepare : bool
            if True, save image.
        """
        if prepare:
            self.create_image()

        with open(os.path.join(self.data_dir, f"{self.file_stem}_label.json"), "r") as f:
            label_data = json.load(f)

        image_files = [f for f in glob.glob(os.path.join(self.image_dir, "*.png"))]
        images, labels, self.smiles = [], [], []
        for image_file in image_files:
            idx = os.path.splitext(os.path.basename(image_file))[0]
            label = label_data[idx][0]
            image = cv2.imread(image_file)

            images.append(image)
            labels.append(label)
            self.smiles.append(label_data[idx][1])

        self.images = np.array(images).transpose(0, 3, 1, 2)
        self.labels = np.array(labels).reshape(-1, 1)

画像の保存は初回のみ必要です。画像を作成・保存するかどうかはcreate_dataset関数の引数prepareで決定しています。2回目以降で画像の作成・保存が必要ない場合はprepare=Falseを指定します。この関数内では画像およびlogp値の読み込みを行っています。最後にself.imagesをtransposeしているのは、使用するニューラルネットワークの入力形式に合わせるためです。

    def __len__(self):
        return len(self.images)


    def __getitem__(self, idx):        
        return self.images[idx], self.labels[idx]

torchのdatasetの形式に合わせるため__len__関数および__getitem__関数を作成します。

ニューラルネットワーク

def create_model(network:str, load_weight:bool=False):
    """Create pytorch model.
    
    Parameters
    ------
    network : str
        network name to use
    load_weight : bool
        load model weight
    
    Returns
    ------
    torchvision.models : pytorch model
    """
    model = torch.hub.load('pytorch/vision:v0.10.0', network, pretrained=True)
    num_ftrs = model.fc.in_features
    model.fc = torch.nn.Linear(num_ftrs, 1)
    model.to(device)

    if load_weight:
        model.load_state_dict(torch.load(model_path))
    
    return model

今回はCNNモデルをtorch.hubからロードして使用しています。引数networkに対応しているネットワーク名を入れることでモデルをロードすることができます。また、今回は回帰問題であるため最終層の出力を1つに変更しています。

学習

def train():
    os.makedirs(weight_dir, exist_ok=True)

    data_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    trainset = Dataset(data_dir, "train_input.csv")
    trainset.create_dataset(prepare=False)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size)
    valset = Dataset(data_dir, "test_input.csv")
    valset.create_dataset(prepare=False)
    valloader = torch.utils.data.DataLoader(valset, batch_size=batch_size)

    model = create_model(model_name)
    model = model.to(device)

    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)

    writer = SummaryWriter()

    print("start training.")
    for epoch in range(epochs):
        step = 0
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs = inputs.to(device).float()
            labels = labels.to(device).float()

            inputs = data_transform(inputs)

            optimizer.zero_grad()

            outputs = model(inputs)
            loss = criterion(outputs, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

            step += 1
        
        epoch_loss = running_loss / len(trainloader)
        mse = val(valloader, model, silent=True)

        writer.add_scalar("loss/train", epoch_loss, epoch)
        writer.add_scalar("loss/validation", mse, epoch)
        print(f"epoch: {epoch}, train loss: {epoch_loss:.2f}, validation loss: {mse:.2f}")

        torch.save(model.state_dict(), model_path)

transforms.Composeで適用する画像前処理を定義しています。適用している前処理はランダムな水平反転と標準化です。ここは垂直方向への反転や回転なども入れる余地があります。
今回は画像が欠損するような前処理は入れていません。なぜなら、分子構造中にlogpに大きく寄与する部分とあまり寄与しない部分があることが予想されるからです。もし分子構造が欠損するような前処理を入れてしまうと、例えばlogp値が大きく増加するような構造がないにもかかわらず大きなlogp値が正解データとなってしまう可能性があります。このような問題を起こさないために画像が欠損するような前処理は除いています。

可視化

def plot_distribution(
        pred:np.ndarray, dataset:torch.utils.data.Dataset
    ) -> None:
    """Plot carbon vs logp
    Parameters
    ------
    pred : np.ndarray
        predicted logp
    dataset : torch.utils.data.Dataset
        used dataset
    """
    with open(os.path.join(data_dir, f"{dataset.file_stem}_label.json"), "r") as f:
        _label = json.load(f)
    label = {}
    for v in _label.values():
        label[v[1]] = v[0]

    fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(9, 6))

    df = dataset.df.copy()
    df["logp"] = df["SMILES"].apply(lambda x: label[x])

    not_carbon = ["Ca", "Cr", "Co", "Cu", "Cs", "Cn", "Ce", "Cf"]
    df["c_count"] = df["SMILES"].apply(lambda x: x.count("C") - sum([x.count(nc) for nc in not_carbon]))
    df.plot.scatter(x="c_count", y="logp", ax=axes[0])
    axes[0].set_title("C vs logp")

    df_pred = pd.DataFrame(pred.flatten(), columns=["pred"])
    df_pred["SMILES"] = dataset.smiles
    df_pred["c_count"] = df_pred["SMILES"].apply(lambda x: x.count("C"))
    df_pred.plot.scatter(x="c_count", y="pred", ax=axes[1])
    axes[1].set_title("C vs prediction")
    
    y_max = max(df["logp"].max(), df_pred["pred"].max()) + 5
    y_min = min(df["logp"].min(), df_pred["pred"].min()) - 5
    axes[0].set_ylim(y_min, y_max)
    axes[1].set_ylim(y_min, y_max)
    
    # plt.show()
    plt.savefig("carbon_logp.png")

炭素数とlogpの関係を可視化するための関数です。
可視化のため、logpの値はあらかじめ作成しておいたjsonファイルから読み込みます。また、SMILES表記から炭素数を数え上げます。この時、Caなど炭素以外にもCの文字列を使用する原子があることに注意する必要があります。本コードでは、Cの総数を取得し、炭素以外のCを減算することで炭素数を算出しています。
さらに、同様の原理でlogpの予測値についてもpandas.DataFrame型でまとめます。
最後に2つのグラフのスケールを一致させてプロットします。スケールを同一のものにすることで比較しやすくしています。

def plot_layer(model, testset, ds, plot_num:int=5):
    feature_extractor = create_feature_extractor(model, {"inception3b": "feature"})
    testloader = torch.utils.data.DataLoader(testset, batch_size=1)
    fig, axes = plt.subplots(nrows=2, ncols=plot_num, figsize=(9, 6))
    step = 0
    for i, (inputs, _) in enumerate(testloader):
        inputs = inputs.to(device).float()
        inputs_aug = ds(inputs)
        feature = feature_extractor(inputs_aug)
        fm = feature["feature"]
        
        inp = inputs[0].to("cpu").detach().numpy().transpose(1, 2, 0).astype(int)
        fm = fm[0, 0, :, :].to("cpu").detach().numpy()

        axes[0, i].imshow(inp)
        axes[1, i].imshow(fm)

        step += 1
        if i >= plot_num - 1:
            break

    # plt.show()
    plt.savefig("layer.png")

中間層を可視化するための関数です。
中間層の抽出はtorchvision.models.feature_extractionのcreate_feature_extractorメソッドを使用すると簡単です。層の名前を引数で指定するだけで中間層を取得することができます。
ちなみに、層の名前を知りたい場合はprint(model)することでネットワーク構造を見ることができます。
ここで、最終層ではなくinception3bを指定している理由は画像が大きくて見やすいからという理由です。この辺りは検討の余地ありです。

結果・考察

tensorboard.png
学習時のlossの変化はtensorboardを使って記録および可視化を行っています。train loss, validation lossともに単調に減少しており学習ができていそうな結果となっています。
carbon_logp.png
lossの変化だけだと本当に学習がうまくいっているのかわからないので、炭素数とlogpの関係(上図左側)を図示しました。一般的に炭素数が多くなるほどlogpの値は大きくなっていきます。図示した結果もおおよそ右肩上がりの関係となっています。
一方、上図右側は炭素数とlogpの推論結果を図示したものとなります。炭素数50以下、予測値が-5~10ぐらいの範囲では、正解データとほぼ同じような傾向を示したプロットになっていることがわかります。しかし、図の右上の部分を見てみるとlogp=16ぐらいでプロットが頭打ちになっているように思えます。これは炭素数およびlogpがこの範囲に入るサンプル数が少ないかったことにより、うまくフィッティングできていないことが原因だと考えられます。
サンプル数が少ない範囲があることは仕方ないので、今回の実験結果としては「多くの分子をカバーできる範囲である炭素数50以下、予測値-5~10の範囲で良好な予測結果を得た」と言えると考えています。
layer.png
CNNでは、その中間層を可視化することによってニューラルネットワークがどの部分に着目して答えを導いているのかを考察することができます。
今回は可視化手法の紹介程度にとどめますが、中間層の出力を解析することによってlogpがどのような構造と関係があるのかを推察することができます。

まとめ

  • GoogleNetを使って構造式からlogpの値を予測するモデルを構築した
  • サンプル数が多い範囲ではおおよそ良好な予測結果を得ることができた
  • さらに中間層を検討することによりlogpの値に寄与する構造を知ることができる
0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?