
SwapNoise Denoising AutoEncoder, Keras [Notes]


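Background

I tried a Denoising AutoEncoder in the Kaggle competition [Jane Street Market Prediction].

I adopted SwapNoise so that the injected noise stays close to values that occur in real data.

Overview of SwapNoise

For each mini-batch, a fraction of the features is selected, and the values of each selected feature are shuffled across the rows of the batch.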

[Figure: swap-noise-image_page-0001.jpg]
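
The idea can be sketched in plain NumPy, independent of Keras. This is only a minimal illustration under stated assumptions: the function name `swap_noise`, the toy batch, and the use of `np.random.default_rng` are illustrative choices, not part of the original implementation.

```python
import numpy as np

def swap_noise(x: np.ndarray, swap_p: float, rng: np.random.Generator) -> np.ndarray:
    """Return a copy of x where each column is, with probability swap_p,
    permuted across the rows of the batch."""
    noisy = x.copy()
    n_rows, n_features = x.shape
    # Pick the feature columns to corrupt for this batch.
    swap_cols = np.flatnonzero(rng.random(n_features) <= swap_p)
    for col in swap_cols:
        # Replace the column with a row-permutation of itself.
        noisy[:, col] = x[rng.permutation(n_rows), col]
    return noisy

rng = np.random.default_rng(0)
batch = np.arange(20, dtype=float).reshape(5, 4)  # toy data: 5 rows, 4 features
noisy = swap_noise(batch, swap_p=0.5, rng=rng)

# Each column keeps the same set of values; only the row order may differ.
same_values = all(
    np.array_equal(np.sort(noisy[:, j]), np.sort(batch[:, j]))
    for j in range(batch.shape[1])
)
print(same_values)
```

Because every corrupted value is drawn from the same column of another row, the noisy input only ever contains values that the feature actually takes in the data.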

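Implementation

Create a generator that returns batches with SwapNoise applied; the model can then be trained with Keras's fit_generator. (In TensorFlow 2.1 and later, fit_generator is deprecated and fit accepts the same generator directly.)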

from typing import List
import numpy as np
from tensorflow import keras
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Activation
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau

Generator

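class SwapNoiseGenerator(keras.utils.Sequence):
    """Yields (noisy input, clean target) batches for a denoising autoencoder."""

    def __init__(self, data: np.ndarray, batch_size: int, swap_p: float):
        self.data = data
        self.n_data, self.n_features = data.shape
        self.batch_size = batch_size
        self.swap_p = swap_p  # probability that a feature column is swapped in a batch

    def __getitem__(self, idx: int):
        s_idx, e_idx = idx * self.batch_size, (idx + 1) * self.batch_size
        # y is the clean batch (target); x is the same batch with swap noise applied.
        x, y = self.data[s_idx:e_idx].copy(), self.data[s_idx:e_idx].copy()
        x = self.swap(x)
        return x, y

    def __len__(self):
        # Number of batches per epoch; the last batch may be smaller.
        return int(np.ceil(len(self.data) / self.batch_size))

    def swap(self, x: np.ndarray) -> np.ndarray:
        # Permute each selected column across the rows of the batch.
        for idx in self.sample_swap_index():
            x[:, idx] = x[self.shuffled_index(n=x.shape[0]), idx]
        return x

    def sample_swap_index(self) -> np.ndarray:
        # Each feature is selected independently with probability swap_p.
        return np.arange(self.n_features)[np.random.rand(self.n_features) <= self.swap_p]

    def shuffled_index(self, n: int) -> np.ndarray:
        index = np.arange(n)
        np.random.shuffle(index)
        return index

    def on_epoch_end(self):
        # Reshuffle rows between epochs. This shuffles self.data in place,
        # so the array passed to the constructor is reordered as well.
        np.random.shuffle(self.data)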
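
The batch arithmetic in `__len__` and `__getitem__` can be checked on its own. The following is a small sketch, assuming a hypothetical helper `batch_slices` (not part of the original code) that reproduces the same index computation without Keras.

```python
import numpy as np

def batch_slices(n_data: int, batch_size: int):
    """Row ranges (start, end) that each batch index maps to."""
    n_batches = int(np.ceil(n_data / batch_size))
    # Slicing past the end of an array is clipped, so the last batch may be short.
    return [(i * batch_size, min((i + 1) * batch_size, n_data)) for i in range(n_batches)]

slices = batch_slices(n_data=10, batch_size=4)
print(slices)                      # [(0, 4), (4, 8), (8, 10)]
print([e - s for s, e in slices])  # [4, 4, 2]
```

Since the swap is done within a batch, a short final batch has fewer rows to shuffle among; a batch containing a single row would come back unchanged.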

AutoEncoder

def create_encoder(input_dim: int, emb_dim: int, n_units: List[int], lr: float):
    """Build the encoder and the full autoencoder; both share the same layers."""
    i = Input(shape=(input_dim,))

    # Encoder: Dense -> BatchNorm -> ReLU for each hidden width.
    encoded = i
    for unit in n_units:
        encoded = Dense(unit)(encoded)
        encoded = BatchNormalization()(encoded)
        encoded = Activation('relu')(encoded)

    # Linear embedding layer; this is the encoder output.
    encoded = Dense(emb_dim)(encoded)

    # Decoder: starts with a layer of width emb_dim, then mirrors the encoder.
    decoded = Dense(emb_dim)(encoded)
    decoded = BatchNormalization()(decoded)
    decoded = Activation('relu')(decoded)

    for unit in n_units[::-1]:
        decoded = Dense(unit)(decoded)
        decoded = BatchNormalization()(decoded)
        decoded = Activation('relu')(decoded)

    # Linear reconstruction of the original features.
    decoded = Dense(input_dim, activation='linear', name='decoded')(decoded)

    encoder = Model(inputs=i, outputs=encoded)
    autoencoder = Model(inputs=i, outputs=decoded)

    autoencoder.compile(optimizer=Adam(learning_rate=lr), loss='mse')
    return encoder, autoencoder
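
To see the shape of the network that `create_encoder` stacks, the widths of its Dense layers can be listed without building the model. This is a sketch only: `layer_widths` is a hypothetical helper, and `input_dim=100` is an arbitrary illustrative value rather than a number from the article.

```python
from typing import List

def layer_widths(input_dim: int, emb_dim: int, n_units: List[int]) -> List[int]:
    """Output width of each Dense layer, in the order create_encoder stacks them."""
    encoder = list(n_units) + [emb_dim]                      # hidden layers, then the embedding
    decoder = [emb_dim] + list(n_units[::-1]) + [input_dim]  # mirrored, then the reconstruction
    return encoder + decoder

# input_dim=100 is an assumption for illustration only.
print(layer_widths(input_dim=100, emb_dim=32, n_units=[128, 64]))
# [128, 64, 32, 32, 64, 128, 100]
```

The third entry is the embedding returned by `encoder`; everything after it belongs to the decoder half of `autoencoder`.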

Usage example

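tr_X = ...  # training features, 2-D np.ndarray of shape (n_samples, n_features)
va_X = ...  # validation features, same number of columns as tr_X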

EMB_DIM = 32
N_UNITS = [128, 64]
LR = 1e-4
EPOCHS = 300
SWAP_P = 0.1
BATCH_SIZE = 128

encoder, autoencoder = create_encoder(
    input_dim=tr_X.shape[1],
    emb_dim=EMB_DIM,
    n_units=N_UNITS,
    lr=LR
)

sng = SwapNoiseGenerator(data=tr_X, batch_size=BATCH_SIZE, swap_p=SWAP_P)
history = autoencoder.fit_generator(
    generator=sng,
    epochs=EPOCHS,
    verbose=1,
    callbacks=[
        EarlyStopping('val_loss', patience=10, restore_best_weights=True),
        ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=5, min_lr=1e-5, verbose=1)
    ],
    validation_data=(va_X, va_X),
    shuffle=True,
)
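
The learning-rate settings above (an initial rate of 1e-4, a reduction factor of 0.2, and `min_lr=1e-5`) leave room for only a couple of effective reductions. The arithmetic can be sketched as follows; `reduced_lrs` is a hypothetical helper that mimics the multiply-then-clip rule of `ReduceLROnPlateau`, not a Keras function.

```python
def reduced_lrs(lr: float, factor: float, min_lr: float, n_reductions: int):
    """Learning rate after each successive plateau-triggered reduction."""
    lrs = [lr]
    for _ in range(n_reductions):
        lr = max(lr * factor, min_lr)  # multiply by factor, but never go below min_lr
        lrs.append(lr)
    return lrs

print(reduced_lrs(lr=1e-4, factor=0.2, min_lr=1e-5, n_reductions=3))
```

The rate goes from 1e-4 to roughly 2e-5 on the first plateau and is clipped to 1e-5 from the second one onward. Because the reduction patience (5) is shorter than the early-stopping patience (10), the rate gets a chance to drop before training is stopped.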
