3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

Axell超解像コンペ敗北の記録 ーpytorch高速化・GAN・SISRモデル調査を添えてー

Last updated at Posted at 2024-09-03

あいさつ

@hotate_2235と申します。
今回は、signateというコンペサイト上で実施されていた、Axell様主催の超解像コンペに参加して敗北した記録を最速で語っていきます。
主に

  • 行った工夫
    • pytorch関連
      • 工夫(高速化、その他)
      • GAN関連(基本的なコード、lossの設定の仕方、知見まとめ)
      • エラー対処
  • 調べたモデル
  • 感想

について、長々と記録がてら書いていくので、気になったところだけ目次から見ていただけたらと思います。

超解像ってどんなもの

超解像(Single Image Super Resolution)とは、おおざっぱに言って画像の解像度を上げる作業を指します。
ここでいう解像度とは、ピクセル数になります。
例えば、224x224の画像を、448x448の画像にするようなものです。
もともとは、bilinearやbicubicといった古典的なアップスケーリング手法がありましたが、最近ではそれをDeep learningモデルを使って行う、というのが主流になっているようです。

今回のコンペ

今回のコンペは、運営から配布された独自データセット(851+100枚)に対して、推論時間の条件を満たすONNXモデルを提出する、というもので、評価指標としては、PSNRが採用されました。

ですので、主な流れとしては

  • pytorch(or any)でモデルを組む
  • 配布データセットで学習する
  • モデルをONNXに変換する
  • 推論時間と精度を確認する

というものになっていました。

ここまでを踏まえて、ここからは私がどんな工夫をしたかなどをいろいろと記載してきます。(幸いモデルも分析結果も公表可能らしいので)

pytorch関連

pytorch関連では、主に

  • 工夫(高速化、その他)
  • エラー対処
  • GAN関連(基本的なコード、lossの設定の仕方、知見まとめ)

を書き連ねていきます。

環境は

OS Windows11 education(wsl2なし)
cpu core i7 11700F
GPU RTX 3060(12GB)

です。これに、以前紹介したノートPCでリモートアクセスする形をとっています。

pytorch-高速化

pytorchでの高速化関連についての話です。

dataloaderが遅すぎる件

今回、いただいたサンプルコードは、ざっくり書くと以下のようにして画像を読み込んで前処理をしていました。

dataloader
import PIL
import torchvision.transforms as transforms

image = PIL.Image.open("image/path") #画像読み込み
image_processed = transforms.Compose([
            transforms.~~~
        ])(image)  #前処理の適用

おそらく、普通に書くとこうなると思うのですが、どうやらこの書き方だと非常に遅くなるようです。というのも、基本的にメインメモリ(RAM)上での作業となり、GPUのメモリ(VRAM)上での作業ではないため、メモリ速度(というか帯域?)に引っ張られ、かつCPUの処理速度にも引っ張られて遅くなるようです。

なので、今回私は、次の記事(Link)を参考にして、以下のようなコードにしました。

dataloader_highspeed
image = transforms.ToTensor()(np.load("image/data.npy")).to(device="cuda")
#npyファイルの読み込み、torch.tensorへの変換、cudaへの転送
#torchvision.io.read_image(image_path)でも可

image = nn.Sequential(
            transforms.~~~
        ).to(torch.device('cuda'))(image)
#前処理をnn.Sequentialで記述し、cudaに流して処理する

最初のコードとの違いとしては、画像データ自体をあらかじめnpyファイルにして置き、それを読み込む形式にしたことと、前処理をnn.Sequentialに流すようにしたことです。これによって、学習はじめのdataloader部分で時間をとられるということがほぼ無くなりました。

この、すべてをcudaに投げつけるという手法で十分な速度が出るようになったので、私は満足して切り上げましたが、より高速化する手段自体はまだあるそうです。

ほかには、NVIDIA提供のMerlin Data Loader(Link)を使用する方法があるそうですが、これはどうやらRAPIDSなるライブラリの一部?として提供されているようです。(Link)
ただ、OSがLinux必須(or wsl)だったので諦めました。

推論速度向上で幸せ

ほとんど次の記事の焼き直しです。

概論としては、Automatic Mixed Precision(AMP)を使って、fp32ではなくその半分のfp16(or bfloat16)で計算をさせる、というものです。(bf16はampere以降のGPUのみ対応(geforce rtxは3000番代以降))
学習時のコードはこのような形になります。(推論時もほぼ同様なので省略)

pytorch_inference
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = yourmodel()
model.to(device)
#modelの定義

train_dataset, validation_dataset = ~~
train_data_loader, validation_data_loader = ~~
#dataloaderの定義

optimizer = RAdam(model.parameters(), lr=learning_rate, weight_decay=1e-5)
criterion = MSELoss()
#optimizerとloss関数の定義

## 通常時との変更点
scaler = torch.GradScaler() # scalerの導入
torch.backends.cudnn.benchmark = True #ネットワーク計算の最適化?

for epoch in range(num_epoch):
        for idx, (image1, image2) in enumerate(train_data_loader):
            optimizer.zero_grad(set_to_none=True)
            image1, image2 = image1.to(device), image2.to(device)

            ##通常時との変更点
            with torch.autocast(device_type="cuda", dtype = torch.bfloat16):
            #このwith内部をbf16で計算してくれる。
                output = model(image1)
                loss = criterion(output, image2)
            scaler.scale(loss).backward() # loss.backward()に対応
            scaler.unscale_(optimizer) #gradient clippingのためにfp32に修正
            
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, error_if_nonfinite=True)
            
            scaler.step(optimizer) #optimizer.step()に対応
            scaler.update()

        #modelの勾配計算をしない場合の推奨される処理
        model.eval()
        for param in model.parameters():
            param.grad = None

上記のコードによって、AMPを適用した形で学習を行うことが出来ます。
高速化という点のみならず、メモリ使用量の削減も行えるので、可能な限り導入することをお勧めします。
注意点として、通常はfp32で計算される勾配を、scalerでfp16などに直す際には、2**16で割る、という処理が入っています。このせいで勾配がnanやinfになる場合があるらしいです。その場合は、

scaler = torch.GradScaler(init_scale = 4096)

などとすれば解決する(場合がある)そうです。

pytorch工夫(その他)

精度向上のための工夫として、知ってよかったものを記載します。

deform conv2dはいいぞ

deformable convolutionというものがあります。これは、convolutionで使われる畳み込み計算において、そのカーネルの形状自体を学習させるというものです。

例えば、3x3の畳み込みを行う場合、畳み込みに使われるデータは、3x3の中心座標を(0,0)としたら、
(-1.-1)~(1,1)の9個になります。この、中心座標からのずれの値を学習可能なパラメータにすることで、畳み込みに使われる領域自体を変化させるのがdeform convです。
image.png(元論文から引用している解説記事から引用)

現在では、ver4までが提案されており、pytorchではver2まで公式実装されています。
非公式だとこのようなライブラリも存在します。

実際に使ったときに、信じられないくらいわかりにくかったので、サンプルコードを置いておきます。

deformconv2dのサンプルコード

普通にclassで定義する場合は以下の通り

deformconv2d_sample

class model(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.defconv_offset = nn.Conv2d(
            in_channels= 3,
            out_channels =18,
            kernel_size = 3,
            stride = 1,
            padding = 1
        )
        #畳み込みに使われるデータが中心からどれだけずれているか、を取得するための畳み込み
        #out_channelsは2*groups*(kernel_size**2)となる
        #offsetは二次元方向へのずれなので、変数としては2つ必要となる
        #kernel_sizeが9の場合、inputのH,W座標一つに対して、9個のoffset情報が必要になる。
        #ここのoffset情報は二つの変数で表現されるので、out_channelsは2*kernel_size**2となる
        #groups>1の場合、それぞれのgroupに対してoffsetが必要になるので、channelsにgroupsを掛ける

        self.mod = nn.Conv2d(
            in_channels= 3,
            out_channels =9,
            kernel_size = 3,
            stride = 1,
            padding = 1
        )
        #畳み込み計算時に、各座標ごとの重みを決めるための畳み込み(SEnetのexcitationみたいなやつ)
        #各座標に対して、畳み込みに使うfilter分の情報があればいいので、sizeは以下の通り
        #out_channels = offset_groups * kernel_height * kernel_width

        self.Dconv = torchvision.ops.DeformConv2d(
            in_channels = 3, 
            out_channels = 3,
            kernel_size = 3, 
            padding = 1,
            groups = 1
            )
        #メインとなるdeformconv2dの定義
        #通常のnn.Conv2dと定義の仕方は一緒
            
    def forward(self, X):
        X = self.Dconv(
            input = X, 
            offset = self.defconv_offset(X), 
            #mask = torch.sigmoid(self.mod(X))
        )
        #使う場合、input、offsetが必須(maskは任意)

少し変則的だが、計算としてのdeformconv2dだけ使う場合は以下の通り

deformconv2d_sample

class model(nn.Module):
    def __init__(self) -> None:
        super().__init__()

        self.deform_conv = nn.Conv2d(
            in_channels= 3,
            out_channels =3,
            kernel_size = 3,
            stride = 1,
            padding = 1
        )

        self.defconv_offset = nn.Conv2d(
            in_channels= 3,
            out_channels =18,
            kernel_size = 3,
            stride = 1,
            padding = 1
        )

        self.mod = nn.Conv2d(
            in_channels= 3,
            out_channels =9,
            kernel_size = 3,
            stride = 1,
            padding = 1
        )

    def forward(self, X):
        X = torchvision.ops.deform_conv2d(
            input = X, 
            offset = self.defconv_offset(X), 
            weight = self.deform_conv.weight, 
            #mask = torch.sigmoid(self.mod(X)),
            padding = (1, 1)
        )
        #基本的な使い方は同じ

conv2d(groups=)はいいぞ

nn.conv2dでgroupsを指定するのはとてもよきです。なぜかというと、inputのchannel方向に情報を分けて計算できたり、はたまたchannel方向にrepeatしたものを入れることで、実質的にアンサンブル学習になったりという利点があるからです。
どういうことかというと、通常時の畳み込みは、inputの全チャネルの情報を活用して行われますが、grouped convにすることで、チャネルを分割してそれぞれに対しての畳み込みをまとめて行うことが出来るためです。

例えば、inputが(1,3,64,64)とし、conv2dのout_channel=3とします。
この時どのような計算が行われるかというと、通常は、(3,64,64)のチャネルで畳み込まれて1チャネル目、別の(3,64,64)のチャネルで2チャネル目、さらに別の(3,64,64)のチャネルで3チャネル目が計算されてまとめられ、出力されます。
この時、出力された3チャネルには、inputの3チャネルすべての情報が混じっていることになります。
ここで、groups=3とした際に何が起こるかというと、まずinputは(1,64,64)が3枚という形に分割され、畳み込みに使うチャネルも(1,64,64)が3枚となって、それぞれが独立して畳み込まれることになります。そのため、outputの3チャネルは、すべてinputの3チャネルの中の一つの情報しか持っていないことになります。
(例えばinputがRGBだとしたら、outputはR情報由来、G情報由来、B情報由来と完全に分離される)

チャンネルごとに持っている情報を分けつつ処理できる、という利点と、inputをchannel方向にrepeatして入力することで、完全別経路で並列に処理することが出来るという利点があります。

これによって、channelそれぞれの情報に対しての最適な処理をできたり、同じデータに対しての別処理をすることで、多様なデータを取り出せたり、ということになるわけです。

なので、雑に精度を上げられるためおすすめしたいです。(ResNeXtあたりで使われてたはず)
(grouped convの解説記事リンク)
※ただし、ONNXに変換して使う場合は速度がかなり遅くなります。

optimizer紹介

二つ紹介します。自分では理解できてないので、解説記事リンクを貼っておきます。

  • RAdam
    学習率をあまりいじらなくても、いい成果をもたらしてくれる感

  • SAM(実装Link)
    局所的な最適解ではなく、大域的最適解に行こうとする感じ
    (今回のタスクだと、収束までの時間がかかりすぎたため却下した。trainのpsnr<valのpsnr、というまま学習が続いていたので、かなり正則化がかかってそう)

SAMはともかく、RAdamはpytorchの公式実装もあるので、ぜひ使ってみてください。

汎化性能向上のためのAWP

AWP(Adversarial Weight Perturbation)は、非常にざっくりいうとモデルのパラメーターに摂動を与えることで、モデルの頑健性を上げるものです。
詳しい解説は次のリンク先にあります。(解説リンク)

ついでに、以下が今回作成したコードの参照元で、その下に実際に今回使ってたコードとその使用例をまとめておきました。
実装コード参照元

今回のために改良したAWPコードと使用例
AWP_sample
class AWP:
    def __init__(
        self,
        model,
        criterion,
        optimizer,
        adv_param="weight",
        adv_lr=1.0,
        adv_eps=0.01,
    ):
        self.model = model
        self.criterion = criterion
        self.optimizer = optimizer
        self.adv_param = adv_param
        self.adv_lr = adv_lr
        self.adv_eps = adv_eps
        self.backup = {}
        self.backup_eps = {}

    def attack_backward_clasiffer(self, inputs, label):
        with torch.cuda.amp.autocast():
            self._save()
            self._attack_step()  # モデルを近傍の悪い方へ改変
            y_preds = self.model(inputs)
            adv_loss = self.criterion(y_preds.view(-1, 1), label.view(-1, 1))
            mask = label.view(-1, 1) != -1
            adv_loss = torch.masked_select(adv_loss, mask).mean()
            self.optimizer.zero_grad()
        return adv_loss
    
    def attack_backward_CNN(self, inputs, y_ans):
        with torch.cuda.amp.autocast():
            self._save()
            self._attack_step()  # モデルを近傍の悪い方へ改変
            y_preds = self.model(inputs)
            adv_loss = self.criterion(y_preds, y_ans)
            self.optimizer.zero_grad()
        return adv_loss

    def _attack_step(self):
        e = 1e-6
        for name, param in self.model.named_parameters():
            if (
                param.requires_grad
                and param.grad is not None
                and self.adv_param in name
            ):
                norm1 = torch.norm(param.grad)
                norm2 = torch.norm(param.data.detach())
                if norm1 != 0 and not torch.isnan(norm1):
                    r_at = self.adv_lr * param.grad / (norm1 + e) * (norm2 + e)
                    param.data.add_(r_at)
                    param.data = torch.min(
                        torch.max(param.data, self.backup_eps[name][0]),
                        self.backup_eps[name][1],
                    )
                # param.data.clamp_(*self.backup_eps[name])

    def _save(self):
        for name, param in self.model.named_parameters():
            if (
                param.requires_grad
                and param.grad is not None
                and self.adv_param in name
            ):
                if name not in self.backup:
                    self.backup[name] = param.data.clone()
                    grad_eps = self.adv_eps * param.abs().detach()
                    self.backup_eps[name] = (
                        self.backup[name] - grad_eps,
                        self.backup[name] + grad_eps,
                    )

    def _restore(self):
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.data = self.backup[name]
        self.backup = {}
        self.backup_eps = {}
AWP_samplecode
## 通常時との変更点
scaler = torch.GradScaler() # scalerの導入
torch.backends.cudnn.benchmark = True #ネットワーク計算の最適化?
awp = AWP(model, criterion, optimizer, adv_lr=AWP_lr) #AWPインスタンス生成

for epoch in range(num_epoch):
        for idx, (image1, image2) in enumerate(train_data_loader):
            optimizer.zero_grad(set_to_none=True)
            image1, image2 = image1.to(device), image2.to(device)

            ##通常時との変更点
            with torch.autocast(device_type="cuda", dtype = torch.bfloat16):
            #このwith内部をbf16で計算してくれる。
                output = model(image1)
                loss = criterion(output, image2)
            scaler.scale(loss).backward() # loss.backward()に対応

            #追加部分
            loss = awp.attack_backward_CNN(low_resolution_image, high_resolution_image)
            scaler.scale(loss).backward()
            awp._restore()
            #追加部分終わり
            
            scaler.unscale_(optimizer) #gradient clippingのためにfp32に修正
            
            nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0, error_if_nonfinite=True)
            
            scaler.step(optimizer) #optimizer.step()に対応
            scaler.update()

(余談ですが、デフォルトのadv_lrが1と非常に高くなっていますが、これはおそらくある程度学習が進んだ段階での適用を想定しているかつ、勾配も正規化されたうえで学習率と掛けられることからこのような値になっていると思われます。
実際、ある程度収束した段階であれば、adv_lr=1でもvalidationに対しての精度は向上しましたし、反対に学習の初期から導入すると全然精度が安定しなくて苦労しました。)

pytorch関連エラーとその対処

ここからは、怨嗟交えつつのエラー対処の話をしていきます。

backwardでのerror

今回、lossを少し加工して、次のような関数を作成しました。

def loss
~~
tensor = torch.tensor(~)
tensor *= torch.func(tensor)
return tensor

すると、loss.backward()でerrorを吐き、以下のような文章が出力されました。
gradient computation has been modified by an inplace operation ~

これはどうやら a = f(a)のような書き方をすると発生するようです。
なので、以下のような修正を行ったところ解決しました。

def loss
~~
tensor = torch.tensor(~)
tensor_test = torch.func(tensor) * tensor
return tensor_test

pytorchのtensorは、tensor自体にgladの情報がセットで入っているそうで、.backwardすることで勾配を計算してくれるそうです。
ここからは推測ですが、おそらく自分で自分を書き換えるような形になっていることで、勾配計算の部分にバグが生じているのだと思います。

解説記事リンク

悪質だと思うのは、errorの出力場所が微妙に異なるので、対処が面倒くさいことです。
torchでのtensor操作の関数の書き方には気を付けよう。

ONNXへのconvertに対しての怨嗟

今回、コンペの提出形式がONNXモデルを活用した形だったため、
torch.onnx.exportを用いてonnxモデルへの変換を行いました。
そこで問題となったのが、torchでは定義されているけど、onnxでは定義されていない関数の変換です。

例えば、torch.rot90(tensor, i, dim=[2,3])という、配列を回転させる関数がありますが、これをforwardに入れているようなモデルは変換出来ません。
なぜなら、rot90に対して、ONNXへの変換方法が定義されていないためです。
なのでONNXへの変換方法が定義されている別の方法でこの操作を再現する必要があります。
(rot90については、transpose(flip(a))と等価なので、解決)

今回一番苦労したのは、deformconv2dの変換です。

悪戦苦闘の文字化記録

というのも、ONNXには、本体のバージョンのほかに、追加でサポートするoperatorが存在します。(opset version)
deform convはver19で実装だったが、今回のコンペではver17と指定されていました。そのため、ネットの海から、変換のためのライブラリを見つけてきましたが、これがなぜかうまく反応してくれませんでした。
githubのissueページを見ると、matrixのサイズを取得できない、みたいな話とその修正コードがあるので、それを見て書き換えることで、onnxへのexportはできたんですが、なぜか画像サイズの変動に対応できませんでした。
よくよく調べてみると、torch.onnx.exportではtorch.jit.scriptが標準で使われるそうですが、それでうまくいかない場合はtorch.jit.traceベースでの変換になるらしく、そうするとinput shapeの形状変動に対応できないそうです。
変換ライブラリのコードを詳しく見てみると、inputとして受け取っているtensorが、なぜかis_tensorでTrueを返すようになっていないらしく、そこでうまくいっていないことが判明しました。(ついでにissueのサイズエラーもこいつが原因)
結局こいつの修正方法はわからずで諦めました。

そんなわけで、結局変換がうまくいかず、deformconv2dの使用はあきらめました。 opset versionに縛られない世界に行きたい。

画像filter適用

上記onnxへの変換に関連した話として、画像へのfilter処理が挙げられます。
torchvisionには、画像にgaussian_filterを掛ける関数として、torchvision.transforms.GaussianBlur(kernel_size, sigma=(0.1, 2.0))
が存在しますが、これはONNXへの変換が出来ません。なので、自前で用意する必要がありました。
コードとしては以下の通り(Real-ESRGANの再現のために作ったコードなので、ちょっと余計なものが多々ある)

make_gaussian_filter
def calculate_gaussian_kernel(
    kernel_size: int, 
    sigma: list[float] = [1.0, 1.0],
    angle: float = 0,
    beta: float = 1,
    cutoff: float = 1.0,
    device: str = "cuda:0",
    mode: str = "gaussian"
    ) -> torch.Tensor:
    """
    一般Gaussian カーネルを計算します。
    modeとして、gaussian(対称、非対称)、plateu、sincの4つを実装
    """
    import math
    import torch
    import scipy
    angle = math.radians(angle)
    cutoff = math.radians(cutoff)
    size = kernel_size // 2

    rotation_matrix = torch.tensor([[math.cos(angle),-math.sin(angle)],[math.sin(angle), math.cos(angle)]], device = device)
    sigma_matrix = torch.tensor([[sigma[0]**2,0],[0, sigma[1]**2]], device = device)

    variance_matrix = torch.linalg.inv(rotation_matrix@sigma_matrix@(torch.t(rotation_matrix)))
    #print(variance_matrix)

    x = torch.arange(-size, size + 1, device=device).float()
    y = torch.arange(-size, size + 1, device=device).float()
    xx, yy = torch.meshgrid(x, y)
    
    
    if mode == "gaussian":
        matmul_matrix = (variance_matrix[0,0]*xx**2 + variance_matrix[0,1]*xx*yy + variance_matrix[1,0]*xx*yy + variance_matrix[1,1]*yy**2)**beta
        kernel = torch.exp(-0.5 * matmul_matrix)
        kernel /= kernel.sum()
    elif mode == "plateu":
        matmul_matrix = (variance_matrix[0,0]*xx**2 + variance_matrix[0,1]*xx*yy + variance_matrix[1,0]*xx*yy + variance_matrix[1,1]*yy**2)**beta
        kernel = 1/(1+matmul_matrix)
        kernel /= kernel.sum()
    elif mode == "sinc":
        r = torch.sqrt(xx**2 + yy**2)
        #print(r)
        kernel = torch.special.sinc(2*cutoff*r)
        #kernel = cutoff * scipy.special.j1(cutoff * r) / (2 * math.pi * r)
        kernel /= kernel.sum()
    else:
        print("mode is not correct.Please choose gaussian, plateu, or sinc")
        return None

    return kernel

得られたkernelを以下のように適用する

apply_gaussian_filter
dummy  = torch.randn(1,3,128,128)

kernel = calculate_gaussian_kernel(
    kernel_size=3, 
    sigma = [1.0, 1.0],
    mode = "gaussian",
    )
kernel = kernel.unsqueeze(0).unsqueeze(0) #形状を(3,3)から(1,1,3,3)に変更
kernel = kernel.repeat(3,1,1,1) #形状を(1,1,3,3)から(3,1,3,3)に変更

applied_dummy = torch.nn.functional.conv2d(input=dummy, weight=kernel, groups=3)
#weightは、(out_channels, input_channels/groups, kernel_h, kernel_w)となる

このようにして、画像にgaussian_blurを掛けることが出来ます。
注意点として、conv2dを使ってfilterを掛けているため、groupsをinput_channelsと同じにしないといけません。
さもないと、それぞれのchannelにfilterを掛けたものではなく、すべてのchannelにfilterを掛けて、足し合わせたものが返ってくるので要注意です。

pytorch-GAN関連

ここからは、GAN実装関連の話をしていきます。
GAN自体の説明はしません。おざっぱに、出力と識別のネットワークを別で作って、それらが競い合うように学習をする、というようなものです。
一般的?なGANの場合は、generatorの出力した画像をdiscriminatorに入れて、本物か偽物かのラベルを出力させ、generatorはそのラベルを1に、discriminatorはそのラベルを0にするように学習を進めていきます。
超解像の場合は、discriminatorは高解像度画像で1、generator出力画像で0を出力するように学習し、generatorはdiscriminatorに出力画像を入れたときに1になるように学習します。

GAN基本実装コード

まずは基本実装コードです。(with autocast)

GAN-基本コードwith autocast
#modelの定義
D = Discriminator()
G = Generator()
D.to(device)
G.to(device)

#dataloader定義
train_dataset, validation_dataset = ~
train_data_loader ,validation_data_loader = ~ 

#optimizer定義(二つ必要)
D_optim = RAdam(D.parameters(), lr = D_lr, weight_decay = 1e-5)
G_optim = RAdam(G.parameters(), lr = G_lr, weight_decay = 1e-5)

#学習時使うlabel定義(安定化のためラベル平滑化)
realLabel = 0.9*torch.ones(batch_size, 1).cuda()
fakeLabel = 0.1*torch.ones(batch_size, 1).cuda()
realLabel_val = 0.9*torch.ones(1, 1).cuda()
fakeLabel_val = 0.1*torch.ones(1, 1).cuda()

#loss定義
BCE = torch.nn.BCEWithLogitsLoss() #discriminatorの出力用のloss
PerceptualLoss = RegnetPerceptualLoss() #perceptual loss(後述)
criteriation = torch.nn.MSELoss() #画像の二乗誤差用のloss

#lossの係数定義
GAN_loss = 0.1
perce_loss = 0.1
MSE_loss = 1

#高速化のための用意
scaler = torch.GradScaler()
torch.backends.cudnn.benchmark = True

for epoch in range(num_epoch):
        #学習modeへのセット
        D.train()
        G.train()
        for idx, (low_resolution_image, high_resolution_image ) in enumerate(train_data_loader):
                #optimizerの初期化と画像のdev
                D_optim.zero_grad(set_to_none=True)
                G_optim.zero_grad(set_to_none=True)
                low_resolution_image = low_resolution_image.to(device)
                high_resolution_image = high_resolution_image.to(device)
                #モデルの勾配情報初期化
                D.zero_grad()
                with torch.autocast(device_type = "cuda", dtype = torch.bfloat16):
                        #generatorに高解像度画像を作らせる
                        fakeFrame = G(low_resolution_image)

                        #discriminatorにラベルを出力させる
                        DReal = D(high_resolution_image)
                        DFake = D(fakeFrame)

                        #generator画像を偽物、高解像度画像を本物としてい出力しているかのlossを計算
                        D_loss = (BCE(DFake, fakeLabel) + BCE(DReal, realLabel)) / 2
                #discriminatorに関してparameter update
                scaler.scale(D_loss).backward(retain_graph=True) #おそらくretain_graphを入れないとerror?
                scaler.step(D_optim)
                #D_loss.backward(retain_graph=True)
                #D_optim.step()

                #generatorの勾配初期化
                G.zero_grad()
                with torch.autocast(device_type = "cuda", dtype = torch.bfloat16):
                        #generatorの出力した画像が、discriminatorに本物として認識されているかのloss
                        G_label_loss = (BCE(DFake.detach().clone(), realLabel)
                        #DFakeをDFake.detach().clone()にしないと、backwardでerrorを吐く

                        #lossの合計
                        G_loss = MSE_loss * criteriation(fakeFrame, high_resolution_image) + GAN_loss * G_label_loss + perce_loss * PerceptualLoss(fakeFrame, high_resolution_image)
                #generatorのparameter update                    
                scaler.scale(G_loss).backward()
                scaler.step(G_optim)
                scaler.update()
                #G_loss.backward()
                #G_optim.step()

基本的なコードとしては上記のようになります。
ただし、backwardを2回呼ぶ関係で、G_lossの計算の時に、detach().clone()を入れる(+D.optimのupdateでretain_graph=Trueを入れる)必要があるみたいです。
どういう理屈なのかはいまいち分かってません。計算グラフを切ってcloneすることで、勾配情報を含まないscalerとして取り出している、ということなのかなと勝手に思っています。それによって、discriminatorで一度計算したbackwardと被らないようにしているのかなーと。

基本実装コード解説ーperceptual lossについてー

perceptual lossとは、出力された画像と高解像度画像を別の訓練済みモデル(vggなど)に通して、その特徴マップ同士の誤算を計算したものです。
これによって、全体の雰囲気を近しいものにできる?ようです。

(どこで見たか忘れましたが、分類タスクモデルよりsematnic segmentationモデルのほうがいい見たいな話もあるそうです。)

以下、参考サイトと使ったサンプルコード

サンプル

https://qiita.com/AokiMasataka/items/bfb5e338079f01bfc996

VGG loss
class VGGPerceptualLoss(torch.nn.Module):
    def __init__(self):
        super(VGGPerceptualLoss, self).__init__()
        blocks = []
        blocks.append(models.vgg16(pretrained=True).features[:4].eval())
        blocks.append(models.vgg16(pretrained=True).features[4:9].eval())
        blocks.append(models.vgg16(pretrained=True).features[9:16].eval())
        blocks.append(models.vgg16(pretrained=True).features[16:23].eval())
        blocks.append(models.vgg16(pretrained=True).features[23:30].eval())
        for bl in blocks:
            for p in bl:
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks).cuda()
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False).cuda()
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False).cuda()

    def forward(self, fakeFrame, frameY):
        fakeFrame = (fakeFrame - self.mean) / self.std
        frameY = (frameY - self.mean) / self.std
        loss = 0.0
        x = fakeFrame
        y = frameY
        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss += torch.nn.functional.l1_loss(x, y)
        return loss

上のコードだとメモリ使用量が多いので、軽いpretrain modelを使ってやっていた。
以下その参考コード

Regnet loss
class RegnetPerceptualLoss(torch.nn.Module):
    def __init__(self):
        super(RegnetPerceptualLoss, self).__init__()
        blocks = []
        blocks.append(torchvision.models.regnet_x_800mf().stem.eval())
        blocks.append(torchvision.models.regnet_x_800mf().trunk_output.block1.eval())
        blocks.append(torchvision.models.regnet_x_800mf().trunk_output.block2.eval())
        blocks.append(torchvision.models.regnet_x_800mf().trunk_output.block3.eval())
        blocks.append(torchvision.models.regnet_x_800mf().trunk_output.block4.eval())
        for bl in blocks:
            for p in bl:
                p.requires_grad = False
        self.blocks = torch.nn.ModuleList(blocks).cuda()
        self.mean = torch.nn.Parameter(torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), requires_grad=False).cuda()
        self.std = torch.nn.Parameter(torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), requires_grad=False).cuda()

    def forward(self, fakeFrame, frameY):
        fakeFrame = (fakeFrame - self.mean) / self.std
        frameY = (frameY - self.mean) / self.std
        loss = 0.0
        x = fakeFrame
        y = frameY
        for block in self.blocks:
            x = block(x)
            y = block(y)
            loss += torch.nn.functional.l1_loss(x, y)
        return loss

VGGはmodel自体のサイズが500 MBほどですが、Regnetは30MB程度なので、非常に軽いです。(精度面での検証は未実施)

但し、モデルをもう一つ使う関係上、メモリ使用量は爆増します。
rtx3060ではvgg16に耐えられなかったので、Regnetなる軽いモデルに切り替えてやってました。

Relativistic Discriminator

大雑把にいうと、discriminatorに、本物と偽物を比較して、より本物らしいと出力させるlossになっています。
image.png

これをすることで、GANの学習を安定化させることが出来るらしいです。

この部分のみのサンプルコード

discriminatorのlossと、G_label_lossを以下のコードで置き換える感じで使ってください。

relative label loss
BCE = torch.nn.BCEWithLogitsLoss() #discriminatorの出力用のloss

#discriminatorのloss
DReal = D(high_resolution_image)
DFake = D(fakeFrame)
DReal_delta = DReal - torch.mean(DFake, dim=0, keepdim=True).repeat(batch_size,1,1,1)
DReal_delta = torch.mean(DReal_delta.reshape(batch_size, -1), dim = 1, keepdim=True)
DReal_delta = F.sigmoid(DReal_delta)
DFake_delta = DFake - torch.mean(DReal, dim=0, keepdim=True).repeat(batch_size,1,1,1)
DFake_delta = torch.mean(DFake_delta.reshape(batch_size, -1), dim = 1, keepdim=True)
DFake_delta = F.sigmoid(DFake_delta)g
D_loss = (BCE(DFake_delta, fakeLabel) + BCE(DReal_delta, realLabel)) / 2

#generatorのlabel loss
G_label_loss = (BCE(DFake_delta.detach().clone(), realLabel) + BCE(DReal_delta.detach().clone(), fakeLabel)) / 2

Locally discriminative learning(LDL)

見つけたいい感じのloss改善案です。正直よくわかってないので、サンプルコードだけ貼っておきます。
内部で、EMA(Exponential Moving Average)なるものを使っています。gpt4oくんに聞いたら出てきたものをまんま使ってます。

LDLサンプルコード
EMA-def
class EMA:
    def __init__(self, model, decay=0.999):
        self.model = model
        self.decay = decay
        self.shadow = {}
        self.backup = {}

        # Initialize the shadow dictionary with the model's initial parameters
        for name, param in model.named_parameters():
            if param.requires_grad:
                self.shadow[name] = param.data.clone()

    def update(self):
        # Update the shadow weights with the current model parameters
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                # EMA calculation
                new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
                self.shadow[name] = new_average.clone()

    def apply_shadow(self):
        # Backup the current model parameters and replace them with the shadow weights
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.backup[name] = param.data.clone()
                param.data = self.shadow[name]

    def restore(self):
        # Restore the original model parameters from the backup
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                param.data = self.backup[name]
        self.backup = {}

論文中で実装されていたrefine map用の関数

make-map
def make_map(image:torch.tensor) -> torch.tensor:
    """
    imageに対して、7*7のカーネルを設定し、それぞれの領域で分散を計算する。
    その後、全体に対しての分散を1/5乗した値と掛け合わせて、マップとして出力する
    """
    image_h, image_w = image.size(2), image.size(3)
    m = torch.nn.Unfold(kernel_size=7, dilation=1, padding=3, stride=1)
    image_1, image_2, image_3 = image[:,0,:,:,].unsqueeze(0),image[:,1,:,:,].unsqueeze(0),image[:,2,:,:,].unsqueeze(0)
    local_var_1 = m(image_1).var(dim=1, keepdim=True).reshape(-1,1,image_h,image_w)
    local_var_2 = m(image_2).var(dim=1, keepdim=True).reshape(-1,1,image_h,image_w)
    local_var_3 = m(image_3).var(dim=1, keepdim=True).reshape(-1,1,image_h,image_w)#1,1,h,w
    local_var_11 = local_var_1 *torch.var(local_var_1)
    local_var_22 = local_var_2 *torch.var(local_var_2)
    local_var_33 = local_var_3 *torch.var(local_var_3)
    local_var = torch.cat([local_var_11,local_var_22,local_var_33],dim=1)#1,3,h,w
    return local_var
LDL_sample
#変数定義
perloss_cof = 1.0
GAN_cof = 0.1
L1_cof = 1
arti_coef = 1

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#model定義
D = Discriminator()
G = Generator()
D.to(device)
G.to(device)

#EMA定義
ema = EMA(G, 0.999)

#dataloader定義
train_dataset, validation_dataset = get_dataset()
train_data_loader, validation_data_loader = ~

#optimizer定義
D_optim = RAdam(D.parameters(), lr = D_lr, weight_decay = 1e-5)
G_optim = RAdam(G.parameters(), lr = G_lr, weight_decay = 1e-5)

#ganラベル(平滑)
realLabel = 0.9*torch.ones(batch_size, 1).cuda()
fakeLabel = 0.1*torch.ones(batch_size, 1).cuda()
realLabel_val = 0.9*torch.ones(1, 1).cuda()
fakeLabel_val = 0.1*torch.ones(1, 1).cuda()

#loss定義
BCE = torch.nn.BCEWithLogitsLoss()
PerceptualLoss = RegnetPerceptualLoss()
criteriation = torch.nn.MSELoss()

#高速化あれこれ
scaler = torch.GradScaler()
torch.backends.cudnn.benchmark = True

for epoch in range(num_epoch):
        D.train()
        G.train()
        train_loss = 0.0
        validation_loss = 0.0
        train_psnr = 0.0
        validation_psnr = 0.0
        for idx, (low_resolution_image, high_resolution_image ) in enumerate(train_data_loader):
                D_optim.zero_grad(set_to_none=True)
                G_optim.zero_grad(set_to_none=True)
                low_resolution_image = low_resolution_image.to(device)
                high_resolution_image = high_resolution_image.to(device)
                with torch.autocast(device_type = "cuda", dtype = torch.bfloat16):
                        fakeFrame = G(low_resolution_image)#Gに作らせた高解像度画像
                        D.zero_grad()
                        DReal = D(high_resolution_image)
                        DFake = D(fakeFrame)
                        DReal_delta = DReal - torch.mean(DFake, dim=0, keepdim=True).repeat(batch_size,1,1,1)
                        DReal_delta = torch.mean(DReal_delta.reshape(batch_size, -1), dim = 1, keepdim=True)
                        DReal_delta = F.sigmoid(DReal_delta)
                        DFake_delta = DFake - torch.mean(DReal, dim=0, keepdim=True).repeat(batch_size,1,1,1)
                        DFake_delta = torch.mean(DFake_delta.reshape(batch_size, -1), dim = 1, keepdim=True)
                        DFake_delta = F.sigmoid(DFake_delta)

                        D_loss = (BCE(DFake_delta, fakeLabel) + BCE(DReal_delta, realLabel)) / 2
                        #GANにおける相対lossを計算する関数。
                        
                        
                #D_loss.backward(retain_graph=True)
                #D_optim.step()
                scaler.scale(D_loss).backward(retain_graph=True)
                scaler.step(D_optim)

                G_optim.zero_grad(set_to_none=True)
                G.zero_grad()

                with torch.autocast(device_type = "cuda", dtype = torch.bfloat16):
                        ema.apply_shadow()#移動平均モデルに置き換え
                        fakeFrame_ema = G(low_resolution_image)#ema modelでの生成画像
                        refine_map = make_map(high_resolution_image - fakeFrame)
                        #論文で定義されていた、それぞれの座標で7x7のカーネルの分散を求め、5乗根を取ったもの
                        true_map = torch.abs(high_resolution_image - fakeFrame) > torch.abs(high_resolution_image - fakeFrame_ema)
                        #上で求めたmapで、高解像度画像と出力画像の差の絶対値が、emaモデルより大きい場所を取得
                        loss_arti = torch.norm(true_map*refine_map*(high_resolution_image - fakeFrame))
                        #先ほど求めた、条件を満たす位置だけでの誤算を計算
                        G_label_loss = (BCE(DFake_delta.detach().clone(), realLabel) + BCE(DReal_delta.detach().clone(), fakeLabel)) / 2
                        G_loss = criteriation(fakeFrame, high_resolution_image)*L1_cof + G_label_loss*GAN_cof + PerceptualLoss(fakeFrame, high_resolution_image)*perloss_cof
                                  
                        G_loss += loss_arti * arti_coef
                        
                scaler.scale(G_loss).backward()
                #G_loss.backward()
                #G_optim.step()
                ema.restore()#保持していたモデルパラメーターの初期化
                        
                scaler.step(G_optim)
                ema.update()
                scaler.update()
# GAN参考資料 GANのトレーニングに役立つヒント集があったので共有しておきます。

調べたモデル

ここからは、調べたモデルや使ったモデルの話をしていきます。
先に参考サイトのリンクだけまとめておきます。

調べたモデル-参考リンクまとめ

1,GCP無料枠で作ったモデル

2,別の超解像コンペ参加者の記録

3,超解像surveyまとめ記事(後半が存在しない)

4,ICCV論文紹介

5,別の超解像コンペ参加者の記録

6,CoatNet解説記事

7, ESCPN解説

8, 超解像歴史(1~3)(ここでは3のみ)

9,4CH-R解説

調べたモデル 実際に使ったやつ編

使ったやつ編です。

ESCPN

シンプルなCNNモデルです。最後にPixelShuffleをすることで拡大画像を得るモデルになっています。
(今回提供されたサンプルコードのモデルであり、結局超えられなかったベースラインモデルになります)

ESCPN(サンプルから微修正済み)
ESPCN
class ESPCN4x(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.scale = 4
        self.conv_1 = nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5, padding=2)
        nn.init.kaiming_normal_(self.conv_1.weight)
        nn.init.zeros_(self.conv_1.bias)

        self.act = nn.GELU()

        self.conv_2 = nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=1)
        nn.init.kaiming_normal_(self.conv_2.weight)
        nn.init.zeros_(self.conv_2.bias)

        self.conv_3 = nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, padding=1)
        nn.init.kaiming_normal_(self.conv_3.weight)
        nn.init.zeros_(self.conv_3.bias)

        self.conv_4 = nn.Conv2d(in_channels=32, out_channels=(3 * self.scale * self.scale), kernel_size=3, padding=1)
        nn.init.kaiming_normal_(self.conv_4.weight)
        nn.init.zeros_(self.conv_4.bias)

        self.pixel_shuffle = nn.PixelShuffle(self.scale)

    def forward(self, X_in: tensor) -> tensor:
        X = self.act(self.conv_1(X_in))#N,64,H,W
        X = self.act(self.conv_2(X))#N,32,H,W
        X = self.act(self.conv_3(X))#N,32,H,W
        X = self.conv_4(X)#N,48,H,W
        X = self.pixel_shuffle(X)#N,3,4H,4W
        X_out = clip(X, 0.0, 1.0)
        return X_out

このモデルの利点は、何より軽いくせにそこそこ精度が高いということです。
ただ、pixelshuffle後に調整するconvなどはないので、そこは改善のよりがあるのかもしれないです。
大体PSNRが28.3くらいでした

4CH-R

input画像を、0, 90, 180, 270度回転させてCNNに通し、その後0, -90, -180, -270度回転させて加算させるモデルです。
CNNのfilterには方向依存性があることから、処理を別経路にすべきなのではないか、ということで提案されていました。(9,4CHR解説参照)
PSNRは28.7くらいでした。(ベストスコア、ただし処理時間オーバー)

4CHRコード
4CHR model
class RCH4_custom(nn.Module):
    """
    4倍解像度の画像を生成するモデル
    1層目は64、2層目は32、1層目と2層目のカーネルサイズは5、3層目は3

    作業としては、1チャネルにして、回して3層
    1層目は64でカーネル5、2層目は32でカーネル5、3層目は16でカーネル3
    これをそれぞれチャネル方向に和をとってから、回転を元に戻して、平均をとる。
    """
    def __init__(self) -> None:
        super().__init__()
        self.scale = 4

        self.process_000 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, padding=2),
            #N*C,64,H,W
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding=2),
            #N*C,32,H,W
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1),
            #N*C,16,H,W
            nn.ReLU(),
            nn.PixelShuffle(self.scale)
            #N*C,1,4*H,4*W
        )

        self.process_090 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.PixelShuffle(self.scale)
        )

        self.process_180 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.PixelShuffle(self.scale)
        )

        self.process_270 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.PixelShuffle(self.scale)
        )

    def forward(self, X_in: tensor) -> tensor:
        X = X_in.reshape(-1, 1, X_in.shape[-2], X_in.shape[-1])
        #N*C,1,H,W
        X_090 = torch.permute(torch.flip(X,dims=[3]),(0,1,3,2))
        X_180 = torch.permute(torch.flip(X_090,dims=[3]),(0,1,3,2))
        X_270 = torch.permute(torch.flip(X_180,dims=[3]),(0,1,3,2))

        X_000 = self.process_000(X)
        X_090 = self.process_090(X_090)
        X_180 = self.process_180(X_180)
        X_270 = self.process_270(X_270)
        #N*C,1,4*H,4*W

        X_090 = torch.permute(torch.flip(X_090,dims=[2]),(0,1,3,2))
        X_180 = torch.permute(torch.flip(X_180,dims=[2]),(0,1,3,2))
        X_180 = torch.permute(torch.flip(X_180,dims=[2]),(0,1,3,2))
        X_270 = torch.permute(torch.flip(X_270,dims=[3]),(0,1,3,2))

        X_out = (X_000 + X_090 + X_180 + X_270)/4
        X_out = X_out.reshape(-1, 3, X_out.shape[-2], X_out.shape[-1])
        X_out = clip(X_out, 0.0, 1.0)
        return X_out

Real ESRGAN データ拡張

モデル自体はESCPNとUnetDiscriminator(著者実装コピペ)ですが、データ拡張のコードを自作しました。
このモデルは、ESRGANをベースとして、画像の前処理をより現実世界の劣化に近づけようというものです。
解説記事はこれです。

これをもとにして、dataloaderの前処理として書いたコードを貼っておきます。

Real ESRGAN前処理

jpeg圧縮のコード

degrade_image_with_jpg
def degrade_image_with_jpg(
    image:torch.tensor,
    quality:int = 80
    )->torch.tensor:

    from io import BytesIO
    from PIL import Image
    image_PIL = transforms.functional.to_pil_image(image)
    buffer = BytesIO()
    image_PIL.save(
        buffer,
        format='JPEG',
        quality=10)

    image_PIL = Image.open(buffer)
    image = transforms.functional.to_tensor(image_PIL)
    return image
Real ESRGAN preprocess
# データセット定義
class DataSetBase(data.Dataset, ABC):
    def __init__(self, image_path: Path):
        self.images = list(image_path.iterdir())
        self.max_num_sample = len(self.images)
        
    def __len__(self) -> int:
        return self.max_num_sample
    
    @abstractmethod
    def get_low_resolution_image(self, image: Image, path: Path)-> Image:
        pass
    
    def preprocess_high_resolution_image(self, image: Image) -> Image:
        return image
    
    def __getitem__(self, index) -> Tuple[Tensor, Tensor]:
        image_path = self.images[index % len(self.images)]
        high_resolution_image = self.preprocess_high_resolution_image(transforms.ToTensor()(np.load(image_path)).to(dtype = torch.float32, device = "cuda"))
        low_resolution_image = self.get_low_resolution_image(high_resolution_image, image_path)
        return low_resolution_image, high_resolution_image

class TrainDataSet_C(DataSetBase):
    def __init__(self, image_path: Path, num_image_per_epoch: int = 2000):
        super().__init__(image_path)
        self.max_num_sample = num_image_per_epoch

    def get_low_resolution_image(self, image: Tensor, path: Path)-> Tensor:
        import random
        import math
        device = 'cuda:0'
        image = image.to(device=device)
        image_h, image_w = image.size(dim=-2), image.size(dim=-1)
        kernel_size = random.choice([7,9,11,13,15,17,19,21])
        angle = random.uniform(-180,180)
        sigma = [random.uniform(0.2,3),random.uniform(0.2,3)]
        p = random.uniform(0,1)
        #最初のぼかし
        if p < 0.1:
            if kernel_size < 13:
                cutoff = random.uniform(60,180)
            else:
                cutoff = random.uniform(36,90)
            weight = calculate_gaussian_kernel(
                kernel_size=kernel_size, 
                cutoff= cutoff,
                mode= "sinc"
                )
            weight = weight.unsqueeze(0).unsqueeze(0).repeat(3,1,1,1)
            image = torch.nn.functional.conv2d(image, weight = weight,padding=kernel_size//2,groups = 3)
        else:
            p1 = random.uniform(0,1)
            if p1 < 0.7:
                weight = calculate_gaussian_kernel(
                    kernel_size=kernel_size, 
                    sigma=sigma,
                    angle = angle,
                    beta= 1,
                    #cutoff= cutoff,
                    mode= "gaussian"
                    )
                weight = weight.unsqueeze(0).unsqueeze(0).repeat(3,1,1,1)
                image = torch.nn.functional.conv2d(image, weight = weight,padding=kernel_size//2,groups = 3)
            elif p1 < 0.85:
                beta = random.uniform(0.5,4)
                weight = calculate_gaussian_kernel(
                    kernel_size=kernel_size, 
                    sigma=sigma,
                    angle = angle,
                    beta= beta,
                    #cutoff= cutoff,
                    mode= "gaussian"
                    )
                weight = weight.unsqueeze(0).unsqueeze(0).repeat(3,1,1,1)
                image = torch.nn.functional.conv2d(image, weight = weight,padding=kernel_size//2,groups = 3)
            else:
                beta = random.uniform(1,2)
                weight = calculate_gaussian_kernel(
                    kernel_size=kernel_size, 
                    sigma=sigma,
                    angle = angle,
                    beta= beta,
                    #cutoff= cutoff,
                    mode= "plateu"
                    )
                weight = weight.unsqueeze(0).unsqueeze(0).repeat(3,1,1,1)
                image = torch.nn.functional.conv2d(image, weight = weight,padding=kernel_size//2,groups = 3)
        #最初のresize
        p = random.uniform(0,1)
        interpolate = random.choice(["area", "bilinear", "bicubic"])
        #拡大:縮小:維持の確率:1st stageは「20%, 70%, 10%」[0.15,1.5]、2nd stageは「30%, 40%, 30%」
        if p < 0.2:
            scale_factor = random.uniform(1.0,1.5)
            p3 = random.choice([0,1,2])
            if p3 == 0:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.BICUBIC)(image)
            elif p3 == 1:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.NEAREST)(image)
            elif p3 == 2:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.BILINEAR)(image)
        elif p < 0.9:
            scale_factor = random.uniform(0.15,1.0)
            p3 = random.choice([0,1,2])
            if p3 == 0:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.BICUBIC)(image)
            elif p3 == 1:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.NEAREST)(image)
            elif p3 == 2:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.BILINEAR)(image)
        else:
            image = image
        #最初のNoise
        p = random.uniform(0,1)
        p_g = random.uniform(0,1)
        if p < 0.5:
            scale_factor = random.uniform(0.05,3)
            pass
            if p_g > 0.4:
                image_array_noise = cp.asarray(torch.mean(image,dim=0,keepdim=True))*255*scale_factor
                image_array = cp.random.poisson(image_array_noise)/(255*scale_factor)#1, 512,512
                image_delta = torch.tensor(image_array, device=device, dtype = torch.float32)-torch.mean(image,dim=0,keepdim=True, dtype = torch.float32)
                image += image_delta.repeat(3,1,1)
                image = torch.clip(image, 0.0, 1.0).to(device=device)
            else:
                image_array = cp.asarray(image)*255*scale_factor
                image_array = cp.random.poisson(image_array)/(255*scale_factor)
                image = torch.tensor(image_array, device=device, dtype = torch.float32)
                image = torch.clip(image, 0.0, 1.0).to(device=device)
        else:
            #グレーノイズが40 %
            sigma = random.uniform(1,30)
            if p_g < 0.4:
                image += torch.normal(mean = 0, std = sigma, size=(1,image.size(dim=-2), image.size(dim=-1)),device=device).repeat(3,1,1)/255
                image = torch.clip(image, 0.0, 1.0).to(device=device)
            else:#カラーノイズが60 %
                image += torch.normal(mean = 0, std = sigma, size=(3, image.size(dim=-2), image.size(dim=-1)),device=device)/255
                image = torch.clip(image, 0.0, 1.0).to(device=device)
                #image = transforms.GaussianNoise(mean=0.0, std=sigma)(image)
        #最初のJPEG
        quality = random.randint(30,95)
        image = degrade_image_with_jpg(image, quality = quality).to(device=device)
        #2ndのぼかし
        kernel_size = random.choice([7,9,11,13,15,17,19,21])
        angle = random.uniform(-180,180)
        sigma = [random.uniform(0.2,3),random.uniform(0.2,3)]
        p_s = random.uniform(0,1)
        if p_s < 0.2:
            pass
        else:
            p = random.uniform(0,1)
            if p < 0.1:
                if kernel_size < 13:
                    cutoff = random.uniform(60,180)
                else:
                    cutoff = random.uniform(36,90)
                weight = calculate_gaussian_kernel(
                    kernel_size=kernel_size, 
                    cutoff= cutoff,
                    mode= "sinc"
                    )
                weight = weight.unsqueeze(0).unsqueeze(0).repeat(3,1,1,1)
                image = torch.nn.functional.conv2d(image, weight = weight,padding=kernel_size//2,groups = 3)
            else:
                p1 = random.uniform(0,1)
                if p1 < 0.7:
                    weight = calculate_gaussian_kernel(
                        kernel_size=kernel_size, 
                        sigma=sigma,
                        angle = angle,
                        beta= 1,
                        #cutoff= cutoff,
                        mode= "gaussian"
                        )
                    weight = weight.unsqueeze(0).unsqueeze(0).repeat(3,1,1,1)
                    image = torch.nn.functional.conv2d(image, weight = weight,padding=kernel_size//2,groups = 3)
                elif p1 < 0.85:
                    beta = random.uniform(0.5,4)
                    weight = calculate_gaussian_kernel(
                        kernel_size=kernel_size, 
                        sigma=sigma,
                        angle = angle,
                        beta= beta,
                        #cutoff= cutoff,
                        mode= "gaussian"
                        )
                    weight = weight.unsqueeze(0).unsqueeze(0).repeat(3,1,1,1)
                    image = torch.nn.functional.conv2d(image, weight = weight,padding=kernel_size//2,groups = 3)
                else:
                    beta = random.uniform(1,2)
                    weight = calculate_gaussian_kernel(
                        kernel_size=kernel_size, 
                        sigma=sigma,
                        angle = angle,
                        beta= beta,
                        #cutoff= cutoff,
                        mode= "plateu"
                        )
                    weight = weight.unsqueeze(0).unsqueeze(0).repeat(3,1,1,1)
                    image = torch.nn.functional.conv2d(image, weight = weight,padding=kernel_size//2,groups = 3)
        #2ndのresize
        p = random.uniform(0,1)
        interpolate = random.choice(["area", "bilinear", "bicubic"])
        #拡大:縮小:維持の確率:1st stageは「20%, 70%, 10%」[0.15,1.5]、2nd stageは「30%, 40%, 30%」[0.3, 1.2]
        if p < 0.3:
            scale_factor = random.uniform(1.0,1.2)
            p3 = random.choice([0,1,2])
            if p3 == 0:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.BICUBIC)(image)
            elif p3 == 1:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.NEAREST)(image)
            elif p3 == 2:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.BILINEAR)(image)
        elif p < 0.7:
            scale_factor = random.uniform(0.3,1.0)
            p3 = random.choice([0,1,2])
            if p3 == 0:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.BICUBIC)(image)
            elif p3 == 1:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.NEAREST)(image)
            elif p3 == 2:
                image = transforms.Resize((round(scale_factor * image.size(dim=-2)), round(scale_factor * image.size(dim=-1))), transforms.InterpolationMode.BILINEAR)(image)
        else:
            image = image
        #2ndのNoise
        p = random.uniform(0,1)
        p_g = random.uniform(0,1)
        if p < 0.5:
            scale_factor = random.uniform(0.05,2.5)
            pass
            if p_g < 0.4:
                image_array_noise = cp.asarray(torch.mean(image,dim=0,keepdim=True))*255*scale_factor
                image_array = cp.random.poisson(image_array_noise)/(255*scale_factor)#1, 512,512
                image_delta = torch.tensor(image_array, device=device, dtype = torch.float32)-torch.mean(image,dim=0,keepdim=True, dtype = torch.float32)
                image += image_delta.repeat(3,1,1)
                image = torch.clip(image, 0.0, 1.0).to(device=device)
            else:
                image_array = cp.asarray(image)*255*scale_factor
                image_array = cp.random.poisson(image_array)/(255*scale_factor)
                image = torch.tensor(image_array, device=device, dtype = torch.float32)
                image = torch.clip(image, 0.0, 1.0).to(device=device)
        else:
            #グレーノイズが40 %
            sigma = random.uniform(1,30)
            if p_g < 0.4:
                image += torch.normal(mean = 0.0, std = sigma, size=(1, image.size(dim=-2), image.size(dim=-1)),device=device).repeat(3,1,1)/255
                image = torch.clip(image, 0.0, 1.0).to(device=device)
            else:#カラーノイズが60 %
                image += torch.normal(mean = 0.0, std = sigma, size=(3, image.size(dim=-2), image.size(dim=-1)),device=device)/255
                image = torch.clip(image, 0.0, 1.0).to(device=device)
            #2ndのJPEG
        kernel_size = random.choice([7,9,11,13,15,17,19,21])
        p = random.uniform(0,1)
        p3 = random.choice([0,1,2])
        if p < 0.5:
            if p3 == 0:
                image = transforms.Resize((image_h, image_w), transforms.InterpolationMode.BICUBIC)(image)
            elif p3 == 1:
                image = transforms.Resize((image_h, image_w), transforms.InterpolationMode.NEAREST)(image)
            elif p3 == 2:
                image = transforms.Resize((image_h, image_w), transforms.InterpolationMode.BILINEAR)(image)
            ps = random.uniform(0,1)
            if ps < 0.8:
                if kernel_size < 13:
                    cutoff = random.uniform(60,180)
                else:
                    cutoff = random.uniform(36,90)
                weight = calculate_gaussian_kernel(
                    kernel_size=kernel_size, 
                    cutoff= cutoff,
                    mode= "sinc"
                    )
                weight = weight.unsqueeze(0).unsqueeze(0).repeat(3,1,1,1)
                image = torch.nn.functional.conv2d(image, weight = weight,padding=kernel_size//2,groups = 3)
            else:
                pass
            quality = random.randint(30,95)
            image = degrade_image_with_jpg(image, quality = quality).to(device=device)
        else:
            ps = random.uniform(0,1)
            quality = random.randint(30,95)
            image = degrade_image_with_jpg(image, quality = quality).to(device=device)
            if p3 == 0:
                image = transforms.Resize((image_h, image_w), transforms.InterpolationMode.BICUBIC)(image)
            elif p3 == 1:
                image = transforms.Resize((image_h, image_w), transforms.InterpolationMode.NEAREST)(image)
            elif p3 == 2:
                image = transforms.Resize((image_h, image_w), transforms.InterpolationMode.BILINEAR)(image)
            if ps < 0.8:
                if kernel_size < 13:
                    cutoff = random.uniform(60,180)
                else:
                    cutoff = random.uniform(36,90)
                weight = calculate_gaussian_kernel(
                    kernel_size=kernel_size, 
                    cutoff= cutoff,
                    mode= "sinc"
                    )
                weight = weight.unsqueeze(0).unsqueeze(0).repeat(3,1,1,1)
                image = torch.nn.functional.conv2d(image, weight = weight,padding=kernel_size//2,groups = 3)

        image_trans = image
        return transforms.Resize((image.size(dim=-2) // 4, image.size(dim=-1) // 4), transforms.InterpolationMode.BICUBIC)(image_trans)

感想

ようやく、感想まで到達しました。
何より言いたいのは、楽しかったということです。
私は、知らないことを知ることに楽しみを覚えるタイプなので、日々初めて知ることがいろいろと出てきて楽しかったです。濃密なひと月でした。
初めてのコンペ参加にしてはまあ頑張ったんじゃねえかなという感じです。

愚痴を言うなら、シンプルに時間がきつすぎるということですね。最近のモデルは軒並み深いので、あー処理時間収まらないだろうしやめとこ、みたいなのが割とありましたし、何よりこれいいじゃん精度いいじゃん提出!時間オーバー!ダメです!!!は心に来ました。
現実のタスクを想定している関係で、画像サイズが可変だったのもきつかったです。そのせいでattention機構を入れる手段が取れなかったので。(一回画像全体を列ベクトルに直して、というのをやりましたが、メモリが40GBとかになって落ちました。悲しい。)

まあでも、知識は増えたし、楽しかったしで満足です。

最後に言い残すことがあるとしたらこれですかね・・・

ありがとう。RTX3060
(あとコンペ中にいろいろと話を聞いてくれた友人にマジで感謝)

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?