何をするものか
GANでテーブルデータをaugmentationしてやろうというものです。
GANが学習できるぐらいデータ量があるならあんまり効かないんじゃないかという話もありますがこちらの記事によるとそこそこ効いてくれるようです。
https://qiita.com/jovyan/items/c41ab61a6b04e9a6e4df
TGAN
github repository
https://github.com/sdv-dev/TGAN
1. インストール
pipで入ります
pip install tgan
2. 使い方
2-1. データの準備
付属しているサンプルデータでやってみます
from tgan.data import load_demo_data
data, continuous_columns = load_demo_data('census')
print(data.shape)
data.head(3)
2-2. モデルをつくる
データは連続値と離散値の2つに分けて扱われるので、どの列を連続値として扱ってほしいか指定します。
from tgan.model import TGANModel
tgan = TGANModel(continuous_columns)
2-3. 学習する
fit()すればGANモデルの学習がスタートしますが、データに欠損値やinfがあるとエラーになるので適当な前処理が必要です。このデータは1行だけ欠損値が入ったデータがあるのでdropna()して学習してみます。
tgan.fit(data.dropna(axis=0))
2-4. データを生成する
あとはいくつ生成するか指定してGANモデルにデータを作ってもらうだけです
num_samples = 1000
samples = tgan.sample(num_samples)
samples.head(3)
2-5. モデルの保存
作ったGANモデルは以下のようにして保存できます
model_path = 'models/mymodel.pkl'
tgan.save(model_path)
model_path = 'models/mymodel.pkl'
tgan.save(model_path, force=True)
2-6. モデルの読み込み
new_tgan = TGANModel.load(model_path)
new_samples = new_tgan.sample(num_samples)
new_samples.head(3)
CTGAN
TGANの改良版です
github repository
https://github.com/sdv-dev/CTGAN
1. インストール
pipで入ります
pip install ctgan
2. 使い方
2-1. 学習~生成
TGANとは逆で離散値列のlistを渡して指定します。
from ctgan import CTGANSynthesizer
from ctgan import load_demo
data = load_demo()
# Names of the columns that are discrete
discrete_columns = [
'workclass',
'education',
'marital-status',
'occupation',
'relationship',
'race',
'sex',
'native-country',
'income'
]
ctgan = CTGANSynthesizer(epochs=10)
ctgan.fit(data, discrete_columns)
# Synthetic copy
samples = ctgan.sample(1000)
2-2. モデルの保存
ctgan.save('my_model.pkl')
2-3. モデルの読み込み
ctgan2 = CTGANSynthesizer().load('my_model.pkl')
ctgan2