0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?

More than 1 year has passed since last update.

最適輸送距離を使って,簡単にスタイル変換みたいなことをしてみた

Last updated at Posted at 2023-01-22

2023/01/22:最終更新

はじめに

最適輸送距離を用いて,スタイル変換みたいなことをする方法を紹介します.
この手法自体は,以下の最適輸送距離の本に書いてあったものです.
61ClkGlinpL.SX598_BO1,204,203,200.jpg

実際にしたいことは以下の,左の2つの画像から,右の画像を生成することです.
ここでは,右の2つの画像サイズは同じであり,小さいサイズであるときのみ動くコードを紹介します.
最後のコードを実行したら動くようにしています.
connected_img.png

開発環境

Mac OS Montery v.12.6
Python 3.9.11
numpy = 1.23.3
POT = 0.8.2
Pillow = 9.4.0

目次

  1. 2つの画像のサイズを合わせる方法
  2. スタイル変換のようなことをする方法
  3. 画像を結合して保存する方法
  4. 上の3つをまとめたコード

2つの画像のサイズを合わせる方法

resize.py
from PIL import Image


def resize_img(content_img_file, style_img_file, resize=100):
    # 画像の読み込み
    content_img = Image.open(content_img_file)
    style_img = Image.open(style_img_file)

    # 画像のリサイズ
    content_img_resized = content_img.resize((resize, resize))
    style_img_resized = style_img.resize((resize, resize))

    # ファイルを保存
    content_img_resized.save('content_resized.png', quality=90)
    style_img_resized.save('style_resized.png', quality=90)

    return content_img_resized, style_img_resized


content_img_resized , style_img_resized = resize_img(content_img_file, style_img_file)

スタイル変換のようなことをする方法

transfer.py
from PIL import Image
import numpy as np
import ot


def transfer_img(content_img, style_img):
    # 画像の形式の変更
    xs = np.array(content_img, dtype='float64').reshape(-1, 3)
    xt = np.array(style_img, dtype='float64').reshape(-1, 3)

    # 各分布のサンプル数
    n = xs.shape[0]
    # 各点の重さ。今回は全て1/nとしている
    a, b = np.ones((n,)) / n, np.ones((n,)) / n
    # 距離の定義
    C = ot.dist(xs, xt)
    C /= C.max()

    # 最適な輸送方法の計算
    P = ot.emd(a, b, C)

    # Pを用いて実際に輸送してみる
    transferred_img = np.einsum('ij, ki->kj',xt, P)
    transferred_img = transferred_img*n

    # 輸送後の画像の保存
    transferred_img = np.array(transferred_img, dtype='uint8')
    transferred_img = transferred_img.reshape(int(n**(1/2)), int(n**(1/2)), 3)

    pil_img = Image.fromarray(transferred_img)
    pil_img.save('transferred_img.png')
    return pil_img


transferred_img = transfer_img(content_img_resized, style_img_resized)

画像を結合して保存する方法

connect.py
def save_img(images):
    connected_img = Image.new('RGB', (images[0].width*len(images), images[0].height))
    for i in range(len(images)):
        connected_img.paste(images[i], (images[i].width*i, 0))
    connected_img.save('connected_img.png')


save_img([content_img_resized, style_img_resized, transferred_img])

上の3つをまとめたコード

otd.py
from PIL import Image
import numpy as np
import ot


def resize_img(content_img_file, style_img_file, resize=100):
    # 画像の読み込み
    content_img = Image.open(content_img_file)
    style_img = Image.open(style_img_file)

    # 画像のリサイズ
    content_img_resized = content_img.resize((resize, resize))
    style_img_resized = style_img.resize((resize, resize))

    # ファイルを保存
    content_img_resized.save('content_resized.png', quality=90)
    style_img_resized.save('style_resized.png', quality=90)

    return content_img_resized, style_img_resized

def transfer_img(content_img, style_img):
    # 画像の形式の変更
    xs = np.array(content_img, dtype='float64').reshape(-1, 3)
    xt = np.array(style_img, dtype='float64').reshape(-1, 3)

    # 各分布のサンプル数
    n = xs.shape[0]
    # 各点の重さ。今回は全て1/nとしている
    a, b = np.ones((n,)) / n, np.ones((n,)) / n
    # 距離の定義
    C = ot.dist(xs, xt)
    C /= C.max()

    # 最適な輸送方法の計算
    P = ot.emd(a, b, C)

    # Pを用いて実際に輸送してみる
    transferred_img = np.einsum('ij, ki->kj',xt, P)
    transferred_img = transferred_img*n

    # 輸送後の画像の保存
    transferred_img = np.array(transferred_img, dtype='uint8')
    transferred_img = transferred_img.reshape(int(n**(1/2)), int(n**(1/2)), 3)

    pil_img = Image.fromarray(transferred_img)
    pil_img.save('transferred_img.png')
    return pil_img

def save_img(images):
    connected_img = Image.new('RGB', (images[0].width*len(images), images[0].height))
    for i in range(len(images)):
        connected_img.paste(images[i], (images[i].width*i, 0))
    connected_img.save('connectefdfdd_img.png')


if __name__ == '__main__':
    content_img_file = 'content.png'
    style_img_file = 'style.png'
    content_img_resized , style_img_resized = resize_img(content_img_file, style_img_file)
    transferred_img = transfer_img(content_img_resized, style_img_resized)
    save_img([content_img_resized, style_img_resized, transferred_img])

0
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
0
2

Delete article

Deleted articles cannot be recovered.

Draft of this article would be also deleted.

Are you sure you want to delete this article?