17
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 3 years have passed since last update.

【Kaggle】G2Net Gravitational Wave Detection コンペ振り返り

Last updated at Posted at 2021-10-05

はじめに

2021年7月1日〜2021年9月30日に開催されたKaggleのG2Netに社内メンバー( tktk, foo_foo, Keiichi Mase )+ kambehmw とチームを組んで参加しました。

19位のチームの解法はディスカッションに投稿しました。良ければご確認ください。

社内メンバー hirune924, SiNpcw が参加した別チームもあり、こちらは金メダルを獲得しました!!!解法のディスカッションはこちらです。

本記事ではコンペの概要と得られた知見をまとめます。

コンペ概要

疑似的に生成した重力波検出器のデータから、連星ブラックホールの合体による重力波を検出します。

  • データ: 3つの重力波検出器(LIGO Hanford、LIGO Livingston、Virgo)の時系列信号データ
    • train set: 560,000件, test set: 226,000件
    • データの全サイズ: 77.38 GB
  • タスク: 2クラス分類
    • train setのラベルは target=1が重力波あり、target=0が重力波なしでほぼ均等に存在
  • 評価指標: ROC AUC
  • データにノイズが非常に多い。重力波は非常に小さな波であるため、シグナルは検出器のノイズに埋もれている
    • ブラックホールの合体による重力波は特定の信号パターンがあります。周波数が徐々に高くなって振幅も大きくなる信号でチャープ信号と呼ばれます。以下の画像は重力波シグナルの例です。

image-20211004172627936.png
https://www.kaggle.com/c/g2net-gravitational-wave-detection/overview

データ

各データファイルは3つの時系列(各検出器ごとに1つ)が含まれており、2048Hzでサンプリングされた2秒間存在します。
つまり、要素数が[3, 2048 * 2]の配列データです
以下は target=1のデータを色付けしたplotですが、ノイズに埋もれているため重力波シグナルはわかりません。

image-20211003181157714.png

CQT(Constant-Q transform)

CQTは信号データをスペクトログラムに変換する手法です。
短時間フーリエ変換(STFT)と同じように信号に窓を掛けてフーリエ変換しますが、窓に含まれる各周波数成分の周期が同じになるように、着目する周波数によって窓の幅を変えます(低周波では窓幅を大きく、高周波では窓幅を小さくする)。
これにより低周波成分の周波数分解能を高く、高周波成分の時間分解能を高くできます(周波数分解能と時間分解能はトレードオフなので両方高くはできない)。
低周波数を細かく捉えたい時に向いている変換手法という程度でしか理解できていないので、詳しい方がいましたらコメントで教えて頂けるとありがたいです。

本コンペではCQTで作成したスペクトログラムをCNNで分類する方法が人気でした。

nnAudio

nnAudioはPyTorchを使ったオーディオ処理フレームワークです。
人気の公開notebookで使われました。

PytorchユーザならnnAudioのCQT1992v2でCQTを高速に実行できます。

nnAudioはpipでインストールできます。

$ pip install nnAudio

nnAudioのCQTと可視化のサンプルコードは以下です。

import torch
from nnAudio.Spectrogram import CQT1992v2
from matplotlib import pyplot as plt
import numpy as np

signal_names = ("LIGO Hanford", "LIGO Livingston", "Virgo")

def show_cqt(image, target=None, figsize=(15,7)):
    plt.figure(figsize=figsize)
    for j in range(image.shape[2]):
        plt.subplot(3, 3, j + 1)
        plt.title(f"{signal_names[j]}")
        plt.imshow(image[:,:,j], aspect="auto")
        plt.colorbar()
    plt.suptitle(f"target={target}")
    plt.tight_layout()
    plt.show()


waves = np.load("../input/g2net-gravitational-wave-detection/train/0/0/0/00000e74ad.npy") 
target = 1
waves = torch.from_numpy(waves).float()

# CQT
cqt_transform = CQT1992v2(sr=2048, fmin=20, fmax=1024, hop_length=64)
image = cqt_transform(waves)

image = image.transpose(0, 1).transpose(1, 2)
show_cqt(image, target)

image-20211003181052258.png

CQTで可視化しても重力波シグナルが見えないデータがほとんどですが、ノイズよりも重力波シグナルが強いデータもあります。

ちなみに、典型的な重力波シグナルは以下のようなバナナ状のカーブとして表示されます

image-20211003223445997.png

nnAudioはGPUに対応しておりモデルのレイヤーにすれば高速化できます。

以下はtimmを使ったサンプルコードです。

import torch
import torch.nn as nn
from nnAudio.Spectrogram import CQT1992v2
import timm


class CustomModel(nn.Module):
    def __init__(self, model_name, wave_transform_param, pretrained=False):
        super().__init__()
        self.wave_transform = CQT1992v2(**wave_transform_param)
        self.model = timm.create_model(model_name, pretrained=pretrained, in_chans=3, num_classes=1)

    def forward(self, x):
        waves = []
        for i in range(3):
            waves.append(self.wave_transform(x[:, i]))
        x = torch.stack(waves, dim=1)
        output = self.model(x)
        return output
    
model_name = "resnet34"
wave_transform_param = dict(sr=2048, fmin=20, fmax=1024, hop_length=8, window="flattop")
model = CustomModel(model_name, wave_transform_param)
model = model.to("cuda")

waves = np.load("../input/g2net-gravitational-wave-detection/train/0/0/0/00000e74ad.npy") 
waves = torch.from_numpy(waves).float()
waves = waves.unsqueeze(0)
waves = waves.to("cuda")
model(waves)  # tensor([[xxxx]], device='cuda:0', grad_fn=<AddmmBackward>)

コンペの取り組み

nnAudioのCQT1992v2でCQTのスペクトログラムを作成してtimmのモデルで学習回すのをひたすら繰り返しました。
ノイズが多い周波数帯を除くノイズ除去、ノイズを全ての周波数で一様にするwhitening、装置の解像度を補正するdeconvolution、重力波信号テンプレートを用いて微弱信号を検出するmatched filterなどをチームメンバーが試しましたがスコアは伸びませんでした。
コンペ中盤にLIGOの2チャネルを足したものを4番目のチャネルに、LIGOの2チャネルを引いたものを5番目のチャネルに追加するLIGOの干渉を強調する方法でスコア伸ばしました。
コンペ終盤に行った1DCNN、target=0の信号データをmixするaugmentation、信号データの標準化、Pseudo Labelingにより銀圏に残ることができました。
PublicLBは42位でしたが上位チームがshake down(コンペの最終時に公開される最終的な順位のPrivateLBがPublicLBより下がること)したため、PrivateLBは19位でした。
我々のチームがshake downしなかった理由はよくわかりませんが、スピアマンの順位相関係数が高すぎるモデルはアンサンブルから除いて、Cross Validationのスコアが最大になるsubmissionを選んだのが良かったのかもしれません。

性能改善で重要だったこと

1DCNN(1-Dimensional CNN)

  • 1位のチーム2位のチーム3位のチーム4位のチームはCQTよりも高性能な1DCNNのモデルを作成していました。本コンペでは1DCNNのモデルをうまく作れたかが勝負の分かれ目でした。(2位のチームのみwavenetも使っていました)
  • 1DCNNの学習はCQTのモデルに比べてはるかに高速であるため、実験を回しやすいメリットもありました。(計算機の性能に依存しますが、CQTは1epoch1-2時間程度必要なのに対して、6層の1DCNNは1epoch5分程度で学習できました。)
  • 個人的な感触としては前処理でハイパスフィルタやバンドパスフィルタを掛けることが1DCNNでは重要でした。
  • 1DCNNの実装は公開notebookが参考になりそうでした。

nnAudioのCQT1992v2のパラメータ

  • CQT1992v2のパラメータが大きく影響しており、我々のチームはCQTのストライドのサイズであるhop_lengthを小さくしてスペクトログラムの横幅を広げることが一番効果ありました。
  • 4位のチーム8位のチームは最適なパラメータを見つけているようでした。
  • CQT1992v2にはCQTカーネルを学習するtrainableパラメータもあり、18位のチームは学習を成功させていました。

augmentation

  • 我々のチームでも解法に書いてる target=0 + target=0, target=0 + target=1の信号データをmixするaugmentationは効果ありました。
  • 8位のチームはaugmentationがまとまっていてわかりやすかったです。
  • リストアップされた上位のチームではMixup、チャネルをランダムに入れ替えるRandom channel shuffle, LIGOのチャネルだけ入れ替えるLIGO channel shuffle, 波形を時間方向にずらすTime shiftなどを行っていました。

信号データの標準化

  • CQTを行う前に信号データを標準化することは重要でした。我々のチームは以下のように標準化をすることでPublicLB0.878以上のスコアを出せました。
waves = (waves + 3.4332e-20) / (3.5683e-20 + 3.4332e-20)

or

# 波の標準化パラメータ
norm_para = dict(
    mean=(6.90108482e-26, 5.11772679e-26, -1.38312479e-26),
    std=(np.sqrt(5.50605989e-41), np.sqrt(5.50555022e-41), np.sqrt(3.37945454e-42))
)
# 各チャネル標準化
waves_stand = np.stack([
    (waves[0] - norm_para["mean"][0]) / norm_para["std"][0],
    (waves[1] - norm_para["mean"][1]) / norm_para["std"][1],
    (waves[2] - norm_para["mean"][2]) / norm_para["std"][2] 
])
  • 標準化に必要なデータセット全体での平均や標準偏差の計算はこのnotebookを使用しました。tensorflowを使うことで高速に計算できます。

チャネル追加

  • 我々のチームでもLIGOの2チャンネルの加算減算をチャネルに追加するのが有効でした。
  • 3位のチーム13位のチームでもチャネルを追加する方法が使われていました。

Pseudo Labeling

モデルアンサンブル

  • 我々のチームでもCQTと1DCNNのアンサンブルは効果がありました。
  • 上位チームもアンサンブルを使っていました。モデルを大量に作りスタッキングをしているチームが多い印象でした。

その他

  • 1位の解法は合成データ生成と1DCNNでした。合成データは PyCBC / LALSuite で作成したそうです。
  • スペクトログラムの変換手法はMorletウェーブレットを使ったCWTVQTが公開されていました。
    • 私はCWT, VQTどちらもうまく学習できなかったので諦めましたが、 8位のチーム12位のチームによるとCWTはCQTと性能差はあまり無かったそうです。
  • CQTの学習が長時間かかるのが非常につらかったです。
    • V100で高速に学習を回したいと思いcolab pro+を契約しましたが、契約開始直後の24時間しかV100が割り当たらず使い物になりませんでした。
    • コンペ終盤チームメンバーの1人は https://vast.ai/ で学習を回していました。料金がやや安い代わりに接続が切れることがたびたびあったそうです。
    • 私もコンペ終盤は https://jarvislabs.ai/ のA100で学習を回しました。インターフェースはjupyterlabなので使いやすかったです(1時間約3ドルかかりましたが...)。
    • TPUで高速に学習するnotebookが公開されたこともあってか、ディスカッションを見ているとkaggleのTPUを使ってる方が多そうでした(コンペ終盤は混んでて使えないと言ってるディスカッションもありました)。
  • 実験管理にwandbを使いましたが、実験を途中のepochで打ち切ることも多く、管理が面倒になってあまり活用できませんでした。
  • チームメンバと週1で進捗報告をすることはモチベーション維持になり良かったです。
  • チーム名「We did it」は重力波検出の発表会見で出た有名なフレーズらしいです。

まとめ

G2Netは信号データのフレームワークやモデリング手法が学べたコンペでした。
スペクトログラムよりも信号データをそのまま1DCNNで解く方が強いのは意外でした。

本コンペで得られた知見を業務でも活用できればと思います。

最後までお読みいただきありがとうございました。

17
15
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
17
15

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?