SNSにはびこる「バズる」写真。でも、その多くは似たような構図、テーマ、スタイルで、どれもありきたりなものに見えませんか?“映え”や“いいね”を追求するあまり、写真の多様性が忘れられようとしているのです。この問題に取り組むため、私は機械学習を使って、オリジナリティ溢れた“斬新な写真”を見つけ出す試みを行いました。
この記事作成にあたり、サークルの後輩に1000枚ほどの訓練データを提供して頂きました。ありがとうございます。
VAEについて・仮説
では、機械が"斬新な写真"を選べるようにするにはどうすればいいでしょうか。その手段として注目したのが、変分オートエンコーダ(VAE)です。VAEについては詳しい話はここでは省きます。
本稿でポイントになるのは、VAEの潜在変数zが正規分布に従うように訓練がなされるということです。オートエンコーダ(AE)というのは潜在変数zを基に入力信号を再構築しますので、zが近ければ再構築される内容も近いはずです。そこで私は以下のような仮説を立てました。
潜在変数zの値の、分布内での位置を見れば、その写真の斬新さが分かる。具体的には、
ありきたりな写真 : 多くのデータと似通ったデータなので、潜在変数zは分布の平均に近い場所に配置される。
斬新な写真 : 他とは似つかないデータなので、潜在変数zは分布の平均からは離れた場所に配置される。
この考えは一種の異常検出AIと似ています。異常な入力に対する潜在変数は、正常な入力に対する潜在変数の分布から外れているという仮定で、異常検出を行うというものです。異常検出AIに例えれば、ありきたりな写真は正常、斬新な写真は異常になります。
方法
写真を準備する
頂いた写真をVAEで読めるように準備します。そのために、いくつかの処理を行いました。
RAWの現像にはrawpy
というモジュールを使用します。↓こんな感じで使います。
import rawpy
def postprocess_raw_image(fp: Path) -> None:
output_fp = fp.with_suffix(".p.JPG")
if output_fp.exists():
print("Skip:"+fp.name)
return
try:
in_raw = rawpy.imread(str(fp))
out_raw = in_raw.postprocess()
out_image = Image.fromarray(out_raw)
out_image.save(output_fp)
except:
print("Exception:"+fp.name)
RAW形式のファイルに対してはこの関数を使ってJPG画像に変換します。
正方形への変形:正方形の画像をVAEへ入力しようと思ったので、写真を正方形にします。といっても縦横比を変える訳ではなく、黒埋めをするだけです。
def expand_to_square(img: Image.Image) -> Image.Image:
width, height = img.size
if width == height:
return img
if width > height:
new = Image.new(img.mode, (width, width))
new.paste(img, (0, (width-height)//2))
return new
# width<height
new = Image.new(img.mode, (height, height))
new.paste(img, ((height-width)//2, 0))
return new
回転:一部の写真が回転して保存されていたので、再び回転して戻す処理を行いました。
def rotate_image_array(img_array: np.ndarray, anticlockwise: bool = False) -> np.ndarray:
res = np.transpose(img_array, axes=[1, 0, 2])
if anticlockwise:
res = res[::-1]
return res
右回りしたい場合は転置を、左回りしたい場合は転置してから第1軸目の要素を逆転します。
その他、画像の縦横のピクセル数を変えたり(Image.resize
)、画素データを$[0,255]$から$[0,1]$へ正規化したりしました。
モデルを準備する
モデルはTensorFlowで作ります。
VAEのエンコーダは、入力信号に対する潜在変数$z$を出すという説明でしたが、厳密にはちょっと違くて、潜在変数の平均$\mu$と対数分散$\log\sigma^2$の2つの値が出力され、$z\sim\mathcal{N}(\mu,\sigma^2)$であることを表現します。訓練時にはこの分布に従ってランダムに$z$を決定し、逆誤差伝搬できるようにします。このことをRe-parametrization trickと言います。
def ReparametrizationTrick(args):
z_mean, z_logvar = args
eps = K.random_normal(shape=K.shape(z_logvar), mean=0, stddev=1)
return z_mean + eps * K.exp(z_logvar/2)
def build_encoder(img_len:int, z_dim:int):
inputs = Input(shape=(img_len,img_len,3))
x = Conv2D(16,8,activation="relu",strides=8,padding="same")(inputs)
x = Conv2D(32,16,activation="relu",strides=4,padding="same")(x)
x = Flatten()(x)
x = Dense(512,activation="relu")(x)
z_mean = Dense(z_dim,name="z_mean")(x)
z_logvar = Dense(z_dim,name="z_logvar")(x)
z = Lambda(ReparametrizationTrick, output_shape=(z_dim))([z_mean,z_logvar])
encoder = Model(inputs, [z_mean,z_logvar,z], name="encoder")
return encoder
def build_decoder(img_len:int, z_dim:int):
inputs = Input(shape=(z_dim,))
x = Dense(16*16*32, activation="relu")(inputs)
x = Reshape((16,16,32))(x) # 16,16,32
x = Conv2DTranspose(16,16,activation="relu",strides=4,padding="same")(x) # 64,64,16
x = Conv2DTranspose(3,8,activation="sigmoid",strides=8,padding="same")(x) # 512,512,3
decoder = Model(inputs, x, name="decoder")
assert decoder.output_shape == (None, img_len, img_len, 3)
return decoder
class VAE(Model):
def __init__(self, encoder:Model, decoder:Model, img_len:int, **kwargs):
super(VAE, self).__init__(**kwargs)
self.encoder = encoder
self.decoder = decoder
self.img_len = img_len
self.img_len2 = img_len*img_len
def train_step(self,data):
if isinstance(data,tuple):
data = data[0]
with GradientTape() as tape:
z_mean, z_logvar, z = self.encoder(data)
pred = self.decoder(z)
pred_loss = tf.reduce_mean(losses.binary_crossentropy(data, pred)) * self.img_len2
kl_loss = -0.5 * tf.reduce_mean(
1. + z_logvar - tf.square(z_mean) - tf.exp(z_logvar)
)
total_loss = pred_loss + kl_loss
grads = tape.gradient(total_loss, self.trainable_weights)
self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
return{
"loss": total_loss,
"pred_loss": pred_loss,
"kl_loss": kl_loss
}
今回は、入力する画像は$(512,512,3)$、潜在変数は$(128,)$のサイズでいきました。
訓練する
頂いた写真を変換し、その全てを学習データとして学習させました。
IMG_LEN = 512
Z_DIM = 128
encoder = build_encoder(IMG_LEN, Z_DIM)
decoder = build_decoder(IMG_LEN, Z_DIM)
model = VAE(encoder, decoder, IMG_LEN)
model.compile(optimizer=optimizers.Adam())
history = model.fit(
train_img_s,
epochs=100,
batch_size=32,
)
斬新さを計算する
機械学習的にはご法度なのかもしれませんが、学習データを再びVAEに読み込ませて、そこから得られた$z$を使って斬新さを評価しました。
エンコーダからは潜在変数の平均と対数分散が出るという話をしました。今回のTensorFlowの実装では、それらに加えてRe-parametrization trickによって適当に選ばれた潜在変数も出るようになっています。斬新さの評価では、このうち、潜在変数の平均を潜在変数として処理しました。つまり、分散の情報は使いません。
z_mean, z_logvar, _ = model.encoder(train_img_s) # z_meanを使う
やや複雑ですが、ある画像$(i)$を読み込んで出力される潜在変数には平均と対数分散$\mu^{(i)},\log\sigma^{(i),2}$がありますが、その平均は学習データ全体で見るとやっぱり別の平均$\bar\mu$と分散$\sigma_\mu^2$に従っています。この平均の平均$\bar\mu$は私のコード内ではz_mean_mean
と表現しています。
z_mean_mean = np.mean(z_mean, axis=0)
では、斬新さの評価に入ります。その写真$(i)$の斬新さは、潜在変数の平均値z_mean[i]
が他の写真達の平均からどのくらいの距離離れているかで決めます。
で、この「距離」なんですが、今回はユークリッド距離ではなく、マハラノビス距離とかいうもので表すことにします。これは、各次元で分散が異なるような分布中でのデータ間の距離を求めるのに使われます。z_mean
も各次元の分散が結構違ったので(結果で述べる)採用しました。マハラノビス距離を計算するにはscipy
を使います。
from scipy.spatial import distance
z_mean_mean = np.mean(z_mean, axis=0)
z_mean_cov = np.cov(z_mean.numpy().T)
z_mean_cov_inv = np.linalg.pinv(z_mean_cov)
mahalanobis_s = []
for z_i in z_mean:
mahalanobis_s.append(
distance.mahalanobis(z_i,z_mean_mean,z_mean_cov_inv)
)
こうして求まったmahalanobis_s
のうち、大きいものほど斬新な写真であるという訳です。
結果
潜在変数の分布
訓練データをVAEに入力し、z_mean
のとある1つの次元の値を集め、QQプロットにしたところ、以下のようになりました。
青プロットがだいたい赤線に乗っていることから、正規分布に従っていると言えるでしょう。$z$の他の次元についても同じような結果が得られたので、$z$は正規分布に従うように出力されていることがわかります。
- 違う次元同士の共分散が低いことから、各次元は独立していると見られます。
- 同じ次元同士の共分散=分散の値は、次元によって違います。このことから、やはり、潜在変数がどのくらい平均から離れているかを評価するにはマハラノビス距離を採用すべきと考えられます。
写真
写真の質に違いが見られますね👀
考察
- 斬新な写真には被写体が際立つものが、ありきたりな写真には被写体が無いか見えづらいものが多く集まった印象です。
- 彼はよく鉄道や航空機の写真を撮影しているのですが、鉄道と航空機とでその斬新さに大きな違いが見られます。鉄道は様々な構図や場所があるのに対し、航空機は同じ場所から同じような構図で連写して撮るため、鉄道の方が斬新な写真が多い傾向にあると考えられます。
- とはいえ、本稿での方法は必ずしも人の感覚とマッチしないこともあります。例えば、彼が写真展に出していた渾身の1枚が「ありきたりな写真」に分類されていたりしていました😅確かにその写真はのっぺりとしていて他の多数の写真と類似しているかもしれませんが、被写体の構図がとても強くて、かたくまとまっている印象を受けます。このような印象を、VAEはすくい取ることはできないようです。
- 今回は後輩の撮影した写真のみで試しましたが、SNS上などで写真を集めれば、より広汎な「斬新な写真の検出」を行えるかもしれません。
付録
import numpy as np
from typing import List, Optional
import matplotlib.pyplot as plt
def grid_imshow(
img_array_s: List[np.ndarray],
title_s: Optional[List[str]],
scale: float = 3.,
show:bool = True
) -> plt.Figure:
N = len(img_array_s)
root_N = int(np.ceil(np.sqrt(N)))
figlen = int(scale*root_N)
fig, axs = plt.subplots(
root_N, root_N,
figsize=(figlen, figlen), facecolor="white"
)
axs = axs.reshape(-1)
for i, ax in enumerate(axs):
ax.set_axis_off()
if i < N:
ax.imshow(img_array_s[i])
if title_s:
ax.set_title(title_s[i])
if show:
plt.show()
return fig