LoginSignup
3
2

More than 5 years have passed since last update.

keras-1でpix2pix

Last updated at Posted at 2017-03-17

大元のコードはluaでtorch7なんだけど、python使いたいってことで。
https://github.com/phillipi/pix2pix

kerasのソース
https://github.com/tdeboissiere/DeepLearningImplementations/tree/master/pix2pix

必要なライブラリ

keras, theano or tensorflow backend
h5py
matplotlib
opencv 3
numpy
tqdm
parmap

kerasについて

keras-2ブランチで名前が変わってるので古いコードを動かすときは注意
Convolution2D→ Conv2Dなど

sudo pip install keras==1.2.0

pydotなどインストール

sudo apt-get install graphviz
pip install graphviz
pip install pydot==1.1.0

設定ファイルの修正

/etc/matplotlibrc
Aggに変更する。

  • backend : TkAgg
  • backend : Agg

データ処理

facadesデータセットをダウンロード

git clone https://github.com/phillipi/pix2pix.git
cd pix2pix
bash ./datasets/download_dataset.sh facades

HDF5データセットのfacadesを構築

python make_dataset.py /home/user/GitHub/pix2pix/datasets/facades 3 --img_size 256

学習

git clone https://github.com/tdeboissiere/DeepLearningImplementations.git
cd DeepLearningImplementations/pix2pix/src/model
python main.py 64 64 --backend tensorflow

スクリーンショット 2017-03-17 14.25.37.png


元のコード解析

train.lua

環境変数を解析してデフォルトを上書きする。tonumberはluaの関数で文字列から数字への変換。
http://milkpot.sakura.ne.jp/lua/lua51_manual_ja.html#lua_tonumber

for k,v in pairs(opt) do opt[k] = tonumber(os.getenv(k)) or os.getenv(k) or opt[k] end

manualSeedはtorchの関数で乱数ジェネレータのシードを指定された数に設定する。
setdefaulttensortypeはtorchの関数でデフォルトのテンソルタイプを設定します。

torch.manualSeed(opt.manualSeed)
torch.setdefaulttensortype('torch.FloatTensor')

はじめの方は大したこと書いてない。

データを取得

-- create data loader
local data_loader = paths.dofile('data/data.lua')
print('#threads...' .. opt.nThreads)
local data = data_loader.new(opt.nThreads, opt)
print("Dataset Size: ", data:size())
tmp_d, tmp_paths = data:getBatch()

重みの初期化のメソッド

local function weights_init(m)
   local name = torch.type(m)
   if name:find('Convolution') then
      m.weight:normal(0.0, 0.02)
      m.bias:fill(0)
   elseif name:find('BatchNormalization') then
      if m.weight then m.weight:normal(1.0, 0.02) end
      if m.bias then m.bias:fill(0) end
   end
end

生成器と判定器が定義されてる

function defineG(input_nc, output_nc, ngf)
    local netG = nil
    if     opt.which_model_netG == "encoder_decoder" then netG = defineG_encoder_decoder(input_nc, output_nc, ngf)
    elseif opt.which_model_netG == "unet" then netG = defineG_unet(input_nc, output_nc, ngf)
    elseif opt.which_model_netG == "unet_128" then netG = defineG_unet_128(input_nc, output_nc, ngf)
    else error("unsupported netG model")
    end

    netG:apply(weights_init)

    return netG
end

function defineD(input_nc, output_nc, ndf)
    local netD = nil
    if opt.condition_GAN==1 then
        input_nc_tmp = input_nc
    else
        input_nc_tmp = 0 -- only penalizes structure in output channels
    end

    if     opt.which_model_netD == "basic" then netD = defineD_basic(input_nc_tmp, output_nc, ndf)
    elseif opt.which_model_netD == "n_layers" then netD = defineD_n_layers(input_nc_tmp, output_nc, ndf, opt.n_layers_D)
    else error("unsupported netD model")
    end

    netD:apply(weights_init)

    return netD
end

重みの読み込み処理

if opt.continue_train == 1 then
   print('loading previously trained netG...')
   netG = util.load(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_G.t7'), opt)
   print('loading previously trained netD...')
   netD = util.load(paths.concat(opt.checkpoints_dir, opt.name, 'latest_net_D.t7'), opt)
else
  print('define model netG...')
  netG = defineG(input_nc, output_nc, ngf)
  print('define model netD...')
  netD = defineD(input_nc, output_nc, ndf)
end

画像を生成して真偽判定をするメソッド
torch.catは後ろにくっつける

function createRealFake()
    -- load real
    data_tm:reset(); data_tm:resume()
    local real_data, data_path = data:getBatch()
    data_tm:stop()

    real_A:copy(real_data[{ {}, idx_A, {}, {} }])
    real_B:copy(real_data[{ {}, idx_B, {}, {} }])

    if opt.condition_GAN==1 then
        real_AB = torch.cat(real_A,real_B,2)
    else
        real_AB = real_B -- unconditional GAN, only penalizes structure in B
    end

    -- create fake
    fake_B = netG:forward(real_A)

    if opt.condition_GAN==1 then
        fake_AB = torch.cat(real_A,fake_B,2)
    else
        fake_AB = fake_B -- unconditional GAN, only penalizes structure in B
    end
end

3
2
0

Register as a new user and use Qiita more conveniently

  1. You get articles that match your needs
  2. You can efficiently read back useful information
  3. You can use dark theme
What you can do with signing up
3
2