LoginSignup
1
0

CompressAIによる深層学習を用いた画像圧縮

Last updated at Posted at 2023-12-20

はじめに

この記事はLife is Tech ! Advent Calendar 2023 シリーズ2 21日目の記事です。

みなさんごきげんよう。うめちゃんと言います。本業は大学生ですが、Life is Tech!のMinecraftプログラミングコースでメンターもしております。
今日は本業である大学生としての活動に触れます。10月に研究室配属がなされ、成績優秀な私は希望通り画像処理を扱う研究室に配属されました。新人研修的なものとしてCompressAIというライブラリを用いた画像圧縮にしばらく触れたので、それについて書きます。

※本記事では深層学習の詳細な理論については触れません。

CompressAIとは

公式ドキュメントには

CompressAI (compress-ay) is a PyTorch library and evaluation platform for end-to-end compression research.

とあります。「エンドツーエンドの圧縮研究のための PyTorch ライブラリおよび評価プラットフォーム」と訳せます。
深層学習におけるエンドツーエンドとは、入力層から出力層までの重みをいっぺんに学習することを指すようです。
さらに続きを読んでいくと、

CompressAI is built on top of PyTorch and provides:

  • custom operations, layers and models for deep learning based data compression
  • a partial port of the official TensorFlow compression library
  • pre-trained end-to-end compression models for learned image compression
  • evaluation scripts to compare learned models against classical image/video compression codecs

CompressAI aims to allow more researchers to contribute to the learned image and video compression domain, by providing resources to research, implement and evaluate machine learning based compression codecs.

最後の一文だけ訳すと「CompressAI は、機械学習ベースの圧縮コーデックを研究、実装、評価するためのリソースを提供することで、より多くの研究者が学習された画像および動画圧縮の領域に貢献できるようにすることを目指しています。」となります。
機械学習ライブラリであるPyTorchを用いた、画像圧縮と動画圧縮に便利なリソースを提供してくれるのがCompressAIということになります。

使ってみる

紹介はこの程度で済ませて、まずちょっと使ってみます。

インストール

サンプルの学習用スクリプトなども含めてGitHubで公開されていますが、pipでインストールしていきます。condaでもできるようです。

pip install compressai

圧縮してみる

本来ならこの後データセットを使って学習と評価という流れになるのですが、今回はそこは省略して提供されている学習済みモデルを使っていきます。

コード

sample.py
import math
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from compressai.models import FactorizedPrior
from compressai.zoo.image import model_urls
from compressai.zoo.pretrained import load_pretrained
from pytorch_msssim import ms_ssim

def pad(x, p):
    h, w = x.size(2), x.size(3)
    new_h = (h + p - 1) // p * p
    new_w = (w + p - 1) // p * p
    padding_left = (new_w - w) // 2
    padding_right = new_w - w - padding_left
    padding_top = (new_h - h) // 2
    padding_bottom = new_h - h - padding_top
    x_padded = F.pad(
        x,
        (padding_left, padding_right, padding_top, padding_bottom),
        mode="constant",
        value=0,
    )
    return x_padded, (padding_left, padding_right, padding_top, padding_bottom)

def crop(x, padding):
    return F.pad(
        x,
        (-padding[0], -padding[1], -padding[2], -padding[3]),
    )

def compute_psnr(a, b):
    mse = torch.mean((a - b)**2).item()
    return -10 * math.log10(mse)

def compute_msssim(a, b):
    return -10 * math.log10(1-ms_ssim(a, b, data_range=1.).item())

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = FactorizedPrior(128,192)
net.to(device)
net.eval()

dictionary = {}
p = 128
url = model_urls["bmshj2018-factorized"]["mse"][1]
pretrained_dict = torch.hub.load_state_dict_from_url(url)
pretrained_dict = load_pretrained(pretrained_dict)

for k, v in pretrained_dict.items():
	dictionary[k.replace("module.", "")] = v
 
net.load_state_dict(dictionary)
net.update()

img = transforms.ToTensor()(Image.open('sample.png').convert('RGB')).to(device)
x = img.unsqueeze(0)
x_padded, padding = pad(x, p)

with torch.no_grad():
    torch.cuda.synchronize()
    out_enc = net.compress(x_padded)
    torch.save(out_enc, 'compressed')
    
    load_enc = torch.load('compressed')
    out_dec = net.decompress(load_enc["strings"], load_enc["shape"])
    out_dec["x_hat"] = crop(out_dec["x_hat"], padding)
    x_hat = transforms.functional.to_pil_image(out_dec['x_hat'].squeeze(0))
    x_hat.save('result.png')
    
    num_pixels = x.size(0) * x.size(2) * x.size(3)
    print(f'Bitrate: {(sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels):.3f}bpp')
    print(f'MS-SSIM: {compute_msssim(x, out_dec["x_hat"]):.2f}dB')
    print(f'PSNR: {compute_psnr(x, out_dec["x_hat"]):.2f}dB')

学習済みモデルはQualityが1から8まで用意されていて、ここでは最も軽く低画質なQuality=1を用います。
まず学習済みモデルを読み込みます。次に画像を読み込み、圧縮し生成されたbitstreamを保存します。その後bitstreamを保存したファイルから読み込み、展開します。展開した画像を保存し、BitrateとMS-SSIMとPSNRを算出します。

FactorizedPrior

ここではFactorizedPriorというモデルを用います。

compressai/models/google.py
@register_model("bmshj2018-factorized")
class FactorizedPrior(CompressionModel):
    r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
    N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
    <https://arxiv.org/abs/1802.01436>`_, Int Conf. on Learning Representations
    (ICLR), 2018.

    .. code-block:: none

                  ┌───┐    y
            x ──►─┤g_a├──►─┐
                  └───┘    │
                           ▼
                         ┌─┴─┐
                         │ Q │
                         └─┬─┘
                           │
                     y_hat ▼
                           │
                           ·
                        EB :
                           ·
                           │
                     y_hat ▼
                           │
                  ┌───┐    │
        x_hat ──◄─┤g_s├────┘
                  └───┘

        EB = Entropy bottleneck

    Args:
        N (int): Number of channels
        M (int): Number of channels in the expansion layers (last layer of the
            encoder and last layer of the hyperprior decoder)
    """
    
    def __init__(self, N, M, **kwargs):
        super().__init__(**kwargs)

        self.entropy_bottleneck = EntropyBottleneck(M)

        self.g_a = nn.Sequential(
            conv(3, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, N),
            GDN(N),
            conv(N, M),
        )

        self.g_s = nn.Sequential(
            deconv(M, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, N),
            GDN(N, inverse=True),
            deconv(N, 3),
        )

        self.N = N
        self.M = M

# 以下略

convおよびdeconvの中身はPyTorchのnn.Conv2dおよびnn.ConvTranspose2dです。

compressai/models/util.py
def conv(in_channels, out_channels, kernel_size=5, stride=2):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
    )


def deconv(in_channels, out_channels, kernel_size=5, stride=2):
    return nn.ConvTranspose2d(
        in_channels,
        out_channels,
        kernel_size=kernel_size,
        stride=stride,
        output_padding=stride - 1,
        padding=kernel_size // 2,
    )

FactorizedPriorは4層の畳み込み層と逆畳み込み層で構成されていることが分かります。

結果


左が元画像、右が圧縮展開後の画像です。
コスモスの花の中央部(管状花と呼ぶらしい)と、周りの部分(舌状花と呼ぶらしい)に入っている筋が圧縮展開後では劣化しているのが見て取れます。

ファイル サイズ [KB] 圧縮率 [%]
sample.png 581 100
compressed 9.2 1.58
result.png 362 62.3

圧縮してbitstreamにするとファイルサイズが大幅に小さくなります。展開しても画像サイズは同じですが劣化している分だけ小さくなります。ただPNG形式は可逆圧縮であり、たまたまPNGでも小さくなるような劣化の仕方だっただけの可能性もあるのでこのあたりは要検証です。

1通りのQualityで試しただけではあまり意味が無いですが、BitrateとMS-SSIMとPSNRを載せておきます。

指標
Bitrate 0.113 bpp
PSNR 14.96 dB
MS-SSIM 33.27 dB

Bitrateは小さい方が良く、PSNRとMS-SSIMは大きい方が良いです。

関連

横浜国立大学 理工学部 数物・電子情報系学科 電子情報システムEP
横浜国立大学 孫研究室

1
0
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
1
0