2023/01/22:最終更新
はじめに
最適輸送距離を用いて,スタイル変換みたいなことをする方法を紹介します.
この手法自体は,以下の最適輸送距離の本に書いてあったものです.
実際にしたいことは以下の,左の2つの画像から,右の画像を生成することです.
ここでは,右の2つの画像サイズは同じであり,小さいサイズであるときのみ動くコードを紹介します.
最後のコードを実行したら動くようにしています.
開発環境
Mac OS Montery v.12.6
Python 3.9.11
numpy = 1.23.3
POT = 0.8.2
Pillow = 9.4.0
目次
- 2つの画像のサイズを合わせる方法
- スタイル変換のようなことをする方法
- 画像を結合して保存する方法
- 上の3つをまとめたコード
2つの画像のサイズを合わせる方法
resize.py
from PIL import Image
def resize_img(content_img_file, style_img_file, resize=100):
# 画像の読み込み
content_img = Image.open(content_img_file)
style_img = Image.open(style_img_file)
# 画像のリサイズ
content_img_resized = content_img.resize((resize, resize))
style_img_resized = style_img.resize((resize, resize))
# ファイルを保存
content_img_resized.save('content_resized.png', quality=90)
style_img_resized.save('style_resized.png', quality=90)
return content_img_resized, style_img_resized
content_img_resized , style_img_resized = resize_img(content_img_file, style_img_file)
スタイル変換のようなことをする方法
transfer.py
from PIL import Image
import numpy as np
import ot
def transfer_img(content_img, style_img):
# 画像の形式の変更
xs = np.array(content_img, dtype='float64').reshape(-1, 3)
xt = np.array(style_img, dtype='float64').reshape(-1, 3)
# 各分布のサンプル数
n = xs.shape[0]
# 各点の重さ。今回は全て1/nとしている
a, b = np.ones((n,)) / n, np.ones((n,)) / n
# 距離の定義
C = ot.dist(xs, xt)
C /= C.max()
# 最適な輸送方法の計算
P = ot.emd(a, b, C)
# Pを用いて実際に輸送してみる
transferred_img = np.einsum('ij, ki->kj',xt, P)
transferred_img = transferred_img*n
# 輸送後の画像の保存
transferred_img = np.array(transferred_img, dtype='uint8')
transferred_img = transferred_img.reshape(int(n**(1/2)), int(n**(1/2)), 3)
pil_img = Image.fromarray(transferred_img)
pil_img.save('transferred_img.png')
return pil_img
transferred_img = transfer_img(content_img_resized, style_img_resized)
画像を結合して保存する方法
connect.py
def save_img(images):
connected_img = Image.new('RGB', (images[0].width*len(images), images[0].height))
for i in range(len(images)):
connected_img.paste(images[i], (images[i].width*i, 0))
connected_img.save('connected_img.png')
save_img([content_img_resized, style_img_resized, transferred_img])
上の3つをまとめたコード
otd.py
from PIL import Image
import numpy as np
import ot
def resize_img(content_img_file, style_img_file, resize=100):
# 画像の読み込み
content_img = Image.open(content_img_file)
style_img = Image.open(style_img_file)
# 画像のリサイズ
content_img_resized = content_img.resize((resize, resize))
style_img_resized = style_img.resize((resize, resize))
# ファイルを保存
content_img_resized.save('content_resized.png', quality=90)
style_img_resized.save('style_resized.png', quality=90)
return content_img_resized, style_img_resized
def transfer_img(content_img, style_img):
# 画像の形式の変更
xs = np.array(content_img, dtype='float64').reshape(-1, 3)
xt = np.array(style_img, dtype='float64').reshape(-1, 3)
# 各分布のサンプル数
n = xs.shape[0]
# 各点の重さ。今回は全て1/nとしている
a, b = np.ones((n,)) / n, np.ones((n,)) / n
# 距離の定義
C = ot.dist(xs, xt)
C /= C.max()
# 最適な輸送方法の計算
P = ot.emd(a, b, C)
# Pを用いて実際に輸送してみる
transferred_img = np.einsum('ij, ki->kj',xt, P)
transferred_img = transferred_img*n
# 輸送後の画像の保存
transferred_img = np.array(transferred_img, dtype='uint8')
transferred_img = transferred_img.reshape(int(n**(1/2)), int(n**(1/2)), 3)
pil_img = Image.fromarray(transferred_img)
pil_img.save('transferred_img.png')
return pil_img
def save_img(images):
connected_img = Image.new('RGB', (images[0].width*len(images), images[0].height))
for i in range(len(images)):
connected_img.paste(images[i], (images[i].width*i, 0))
connected_img.save('connectefdfdd_img.png')
if __name__ == '__main__':
content_img_file = 'content.png'
style_img_file = 'style.png'
content_img_resized , style_img_resized = resize_img(content_img_file, style_img_file)
transferred_img = transfer_img(content_img_resized, style_img_resized)
save_img([content_img_resized, style_img_resized, transferred_img])