はじめに
この記事はLife is Tech ! Advent Calendar 2023 シリーズ2 21日目の記事です。
みなさんごきげんよう。うめちゃんと言います。本業は大学生ですが、Life is Tech!のMinecraftプログラミングコースでメンターもしております。
今日は本業である大学生としての活動に触れます。10月に研究室配属がなされ、成績優秀な私は希望通り画像処理を扱う研究室に配属されました。新人研修的なものとしてCompressAIというライブラリを用いた画像圧縮にしばらく触れたので、それについて書きます。
※本記事では深層学習の詳細な理論については触れません。
CompressAIとは
公式ドキュメントには
CompressAI (compress-ay) is a PyTorch library and evaluation platform for end-to-end compression research.
とあります。「エンドツーエンドの圧縮研究のための PyTorch ライブラリおよび評価プラットフォーム」と訳せます。
深層学習におけるエンドツーエンドとは、入力層から出力層までの重みをいっぺんに学習することを指すようです。
さらに続きを読んでいくと、
CompressAI is built on top of PyTorch and provides:
- custom operations, layers and models for deep learning based data compression
- a partial port of the official TensorFlow compression library
- pre-trained end-to-end compression models for learned image compression
- evaluation scripts to compare learned models against classical image/video compression codecs
CompressAI aims to allow more researchers to contribute to the learned image and video compression domain, by providing resources to research, implement and evaluate machine learning based compression codecs.
最後の一文だけ訳すと「CompressAI は、機械学習ベースの圧縮コーデックを研究、実装、評価するためのリソースを提供することで、より多くの研究者が学習された画像および動画圧縮の領域に貢献できるようにすることを目指しています。」となります。
機械学習ライブラリであるPyTorchを用いた、画像圧縮と動画圧縮に便利なリソースを提供してくれるのがCompressAIということになります。
使ってみる
紹介はこの程度で済ませて、まずちょっと使ってみます。
インストール
サンプルの学習用スクリプトなども含めてGitHubで公開されていますが、pipでインストールしていきます。condaでもできるようです。
pip install compressai
圧縮してみる
本来ならこの後データセットを使って学習と評価という流れになるのですが、今回はそこは省略して提供されている学習済みモデルを使っていきます。
コード
import math
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from compressai.models import FactorizedPrior
from compressai.zoo.image import model_urls
from compressai.zoo.pretrained import load_pretrained
from pytorch_msssim import ms_ssim
def pad(x, p):
h, w = x.size(2), x.size(3)
new_h = (h + p - 1) // p * p
new_w = (w + p - 1) // p * p
padding_left = (new_w - w) // 2
padding_right = new_w - w - padding_left
padding_top = (new_h - h) // 2
padding_bottom = new_h - h - padding_top
x_padded = F.pad(
x,
(padding_left, padding_right, padding_top, padding_bottom),
mode="constant",
value=0,
)
return x_padded, (padding_left, padding_right, padding_top, padding_bottom)
def crop(x, padding):
return F.pad(
x,
(-padding[0], -padding[1], -padding[2], -padding[3]),
)
def compute_psnr(a, b):
mse = torch.mean((a - b)**2).item()
return -10 * math.log10(mse)
def compute_msssim(a, b):
return -10 * math.log10(1-ms_ssim(a, b, data_range=1.).item())
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
net = FactorizedPrior(128,192)
net.to(device)
net.eval()
dictionary = {}
p = 128
url = model_urls["bmshj2018-factorized"]["mse"][1]
pretrained_dict = torch.hub.load_state_dict_from_url(url)
pretrained_dict = load_pretrained(pretrained_dict)
for k, v in pretrained_dict.items():
dictionary[k.replace("module.", "")] = v
net.load_state_dict(dictionary)
net.update()
img = transforms.ToTensor()(Image.open('sample.png').convert('RGB')).to(device)
x = img.unsqueeze(0)
x_padded, padding = pad(x, p)
with torch.no_grad():
torch.cuda.synchronize()
out_enc = net.compress(x_padded)
torch.save(out_enc, 'compressed')
load_enc = torch.load('compressed')
out_dec = net.decompress(load_enc["strings"], load_enc["shape"])
out_dec["x_hat"] = crop(out_dec["x_hat"], padding)
x_hat = transforms.functional.to_pil_image(out_dec['x_hat'].squeeze(0))
x_hat.save('result.png')
num_pixels = x.size(0) * x.size(2) * x.size(3)
print(f'Bitrate: {(sum(len(s[0]) for s in out_enc["strings"]) * 8.0 / num_pixels):.3f}bpp')
print(f'MS-SSIM: {compute_msssim(x, out_dec["x_hat"]):.2f}dB')
print(f'PSNR: {compute_psnr(x, out_dec["x_hat"]):.2f}dB')
学習済みモデルはQualityが1から8まで用意されていて、ここでは最も軽く低画質なQuality=1を用います。
まず学習済みモデルを読み込みます。次に画像を読み込み、圧縮し生成されたbitstreamを保存します。その後bitstreamを保存したファイルから読み込み、展開します。展開した画像を保存し、BitrateとMS-SSIMとPSNRを算出します。
FactorizedPrior
ここではFactorizedPriorというモデルを用います。
@register_model("bmshj2018-factorized")
class FactorizedPrior(CompressionModel):
r"""Factorized Prior model from J. Balle, D. Minnen, S. Singh, S.J. Hwang,
N. Johnston: `"Variational Image Compression with a Scale Hyperprior"
<https://arxiv.org/abs/1802.01436>`_, Int Conf. on Learning Representations
(ICLR), 2018.
.. code-block:: none
┌───┐ y
x ──►─┤g_a├──►─┐
└───┘ │
▼
┌─┴─┐
│ Q │
└─┬─┘
│
y_hat ▼
│
·
EB :
·
│
y_hat ▼
│
┌───┐ │
x_hat ──◄─┤g_s├────┘
└───┘
EB = Entropy bottleneck
Args:
N (int): Number of channels
M (int): Number of channels in the expansion layers (last layer of the
encoder and last layer of the hyperprior decoder)
"""
def __init__(self, N, M, **kwargs):
super().__init__(**kwargs)
self.entropy_bottleneck = EntropyBottleneck(M)
self.g_a = nn.Sequential(
conv(3, N),
GDN(N),
conv(N, N),
GDN(N),
conv(N, N),
GDN(N),
conv(N, M),
)
self.g_s = nn.Sequential(
deconv(M, N),
GDN(N, inverse=True),
deconv(N, N),
GDN(N, inverse=True),
deconv(N, N),
GDN(N, inverse=True),
deconv(N, 3),
)
self.N = N
self.M = M
# 以下略
conv
およびdeconv
の中身はPyTorchのnn.Conv2d
およびnn.ConvTranspose2d
です。
def conv(in_channels, out_channels, kernel_size=5, stride=2):
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=kernel_size // 2,
)
def deconv(in_channels, out_channels, kernel_size=5, stride=2):
return nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
output_padding=stride - 1,
padding=kernel_size // 2,
)
FactorizedPriorは4層の畳み込み層と逆畳み込み層で構成されていることが分かります。
結果
左が元画像、右が圧縮展開後の画像です。
コスモスの花の中央部(管状花と呼ぶらしい)と、周りの部分(舌状花と呼ぶらしい)に入っている筋が圧縮展開後では劣化しているのが見て取れます。
ファイル | サイズ [KB] | 圧縮率 [%] |
---|---|---|
sample.png | 581 | 100 |
compressed | 9.2 | 1.58 |
result.png | 362 | 62.3 |
圧縮してbitstreamにするとファイルサイズが大幅に小さくなります。展開しても画像サイズは同じですが劣化している分だけ小さくなります。ただPNG形式は可逆圧縮であり、たまたまPNGでも小さくなるような劣化の仕方だっただけの可能性もあるのでこのあたりは要検証です。
1通りのQualityで試しただけではあまり意味が無いですが、BitrateとMS-SSIMとPSNRを載せておきます。
指標 | 値 |
---|---|
Bitrate | 0.113 bpp |
PSNR | 14.96 dB |
MS-SSIM | 33.27 dB |
Bitrateは小さい方が良く、PSNRとMS-SSIMは大きい方が良いです。