A Keras 2 version of "はじめてのGAN" (My First GAN)


Introduction

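I had somehow never built a GAN, but my research now required one, so I decided to implement it. A quick search for introductory material turned up the following article, which looked good.
はじめてのGAN ("My First GAN")
However, perhaps because it is slightly dated, it targets an old version of Keras (even though it was only written last year, in 2017).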

So I resolved the version issues and am leaving the Keras 2 code here as a memo.
The Keras version used in this article is 2.2.4.

Implementation

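The main conceptual change concerns the channel axis. Keras now defaults to putting the channels in the last dimension, so I moved the channel axis from the first position, as in the original article, to the last.
In addition, because the network uses ReLU, I replaced the weight initialization with He initialization.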
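
As a rough illustration of these two changes, the sketch below uses only NumPy, not Keras. It moves the channel axis of a dummy batch from channels-first to channels-last, and computes the standard deviation that He initialization prescribes, sqrt(2 / fan_in). The dummy batch shape and the helper name `he_stddev` are assumptions made for this example and do not appear in the original code.

```python
import math

import numpy as np

# Dummy batch in the channels-first layout: (batch, channels, height, width).
# The batch size of 32 is an assumption for illustration.
batch_channels_first = np.zeros((32, 1, 28, 28), dtype=np.float32)

# Move the channel axis to the end to obtain the channels-last layout:
# (batch, height, width, channels).
batch_channels_last = np.moveaxis(batch_channels_first, 1, -1)
print(batch_channels_last.shape)  # (32, 28, 28, 1)


def he_stddev(fan_in):
    """Standard deviation prescribed by He initialization: sqrt(2 / fan_in)."""
    return math.sqrt(2.0 / fan_in)


# The first Dense layer of the generator receives 100 inputs.
print(round(he_stddev(100), 4))  # 0.1414
```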

generator

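# Imports for all of the code in this article (Keras 2.2.4).
import math
import os

import numpy as np
from PIL import Image
from keras.datasets import mnist
from keras.layers import (Activation, BatchNormalization, Conv2D, Dense, Dropout,
                          Flatten, LeakyReLU, Reshape, UpSampling2D)
from keras.models import Sequential
from keras.optimizers import Adam


def generator_model():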
    data_format = 'channels_last'
    model = Sequential()
    model.add(Dense(input_shape=(100,),units=1024, kernel_initializer='he_normal'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Dense(128*7*7))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(Reshape((7, 7, 128), input_shape=(128*7*7,)))
    model.add(UpSampling2D((2, 2), data_format=data_format))
    model.add(Conv2D(64, (5, 5), padding='same', data_format=data_format, kernel_initializer='he_normal'))
    model.add(BatchNormalization())
    model.add(Activation('relu'))
    model.add(UpSampling2D((2, 2), data_format=data_format))
    model.add(Conv2D(1, (5, 5), padding='same', data_format=data_format, kernel_initializer='he_normal'))
    model.add(Activation('tanh'))
    return model
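
To make the data flow through the generator easier to follow, here is a small sketch that traces the output shape after each shape-changing layer in plain Python. It does not build a Keras model. The function name `generator_shapes` is hypothetical, and the trace assumes that `UpSampling2D((2, 2))` doubles height and width and that `padding='same'` with stride 1 leaves them unchanged.

```python
def generator_shapes(noise_dim=100):
    """Return the output shape after each shape-changing step of the generator."""
    shapes = [(noise_dim,)]           # input noise vector
    shapes.append((1024,))            # Dense(1024)
    shapes.append((128 * 7 * 7,))     # Dense(128*7*7), i.e. 6272 units
    h, w, c = 7, 7, 128
    shapes.append((h, w, c))          # Reshape((7, 7, 128))
    h, w = h * 2, w * 2
    shapes.append((h, w, c))          # UpSampling2D((2, 2))
    c = 64
    shapes.append((h, w, c))          # Conv2D(64, (5, 5), padding='same')
    h, w = h * 2, w * 2
    shapes.append((h, w, c))          # UpSampling2D((2, 2))
    c = 1
    shapes.append((h, w, c))          # Conv2D(1, (5, 5), padding='same')
    return shapes


for shape in generator_shapes():
    print(shape)
```

The last shape is (28, 28, 1), which matches the `input_shape` expected by the discriminator.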

discriminator

def discriminator_model():
    data_format = 'channels_last'
    model = Sequential()
    model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', data_format=data_format,
                            input_shape=(28, 28, 1), kernel_initializer='he_normal'))
    model.add(LeakyReLU(0.2))
    model.add(Conv2D(128, (5, 5), strides=(2, 2), data_format=data_format, kernel_initializer='he_normal'))
    model.add(LeakyReLU(0.2))
    model.add(Flatten())
    model.add(Dense(256, kernel_initializer='he_normal'))
    model.add(LeakyReLU(0.2))
    model.add(Dropout(0.5))
    model.add(Dense(1))
    model.add(Activation('sigmoid'))
    return model
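
The second convolution in the discriminator does not pass `padding='same'`, so it falls back to the Keras default, `'valid'`. The following sketch works out the resulting feature map sizes by hand. The helper `conv_output_size` is hypothetical and simply encodes the usual output-size formulas for the two padding modes.

```python
import math


def conv_output_size(size, kernel, stride, padding):
    """Spatial output size of a convolution along one dimension."""
    if padding == 'same':
        return math.ceil(size / stride)
    if padding == 'valid':
        return (size - kernel) // stride + 1
    raise ValueError('unknown padding: %s' % padding)


first = conv_output_size(28, 5, 2, 'same')        # first Conv2D
second = conv_output_size(first, 5, 2, 'valid')   # second Conv2D (default padding)
flattened = second * second * 128                 # input size of Dense(256)

print(first, second, flattened)  # 14 5 3200
```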

combine_images

def combine_images(generated_images):
    total = generated_images.shape[0]
    cols = int(math.sqrt(total))
    rows = math.ceil(float(total)/cols)
    # The array layout is (batch, height, width, channels).
    height, width = generated_images.shape[1:3]
    combined_image = np.zeros((height*rows, width*cols),
                              dtype=generated_images.dtype)

    for index, image in enumerate(generated_images):
        i = index // cols  # row in the grid
        j = index % cols   # column in the grid
        combined_image[height*i:height*(i+1), width*j:width*(j+1)] = image[:, :, 0]
    return combined_image
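
The grid layout used by `combine_images` can be checked with a few lines of plain Python. This sketch only reproduces the row and column arithmetic. The helper name `grid_shape` and the batch size of 32 are assumptions for illustration.

```python
import math


def grid_shape(total, image_height=28, image_width=28):
    """Return (rows, cols, combined image shape) for `total` tiled images."""
    cols = int(math.sqrt(total))
    rows = math.ceil(total / cols)
    return rows, cols, (rows * image_height, cols * image_width)


# 32 images are laid out in 7 rows of 5 columns; the last row is partly empty.
print(grid_shape(32))  # (7, 5, (196, 140))
```

Cells with no image stay at zero, which appears as mid-gray once the values are mapped back from [-1, 1] to [0, 255].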

train

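# Assumed values: the article does not show these constants.
BATCH_SIZE = 32
NUM_EPOCH = 20
GENERATED_IMAGE_PATH = 'generated_images/'  # the trailing slash is needed because the file name is appended directly


def train():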
    (X_train, y_train), (_, _) = mnist.load_data()
    X_train = (X_train.astype(np.float32) - 127.5)/127.5
    X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], X_train.shape[2], 1)

    discriminator = discriminator_model()
    d_opt = Adam(lr=1e-5, beta_1=0.1)
    discriminator.compile(loss='binary_crossentropy', optimizer=d_opt)

    # generator + discriminator (the discriminator's weights are frozen)
    discriminator.trainable = False
    generator = generator_model()
    dcgan = Sequential([generator, discriminator])
    g_opt = Adam(lr=2e-4, beta_1=0.5)
    dcgan.compile(loss='binary_crossentropy', optimizer=g_opt)

    num_batches = int(X_train.shape[0] / BATCH_SIZE)
    print('Number of batches:', num_batches)
    for epoch in range(NUM_EPOCH):

        for index in range(num_batches):
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            image_batch = X_train[index*BATCH_SIZE:(index+1)*BATCH_SIZE]
            generated_images = generator.predict(noise, verbose=0)

            # Save the generated images
            if index % 200 == 0:
                image = combine_images(generated_images)
                image = image*127.5 + 127.5
                if not os.path.exists(GENERATED_IMAGE_PATH):
                    os.mkdir(GENERATED_IMAGE_PATH)
                Image.fromarray(image.astype(np.uint8))\
                    .save(GENERATED_IMAGE_PATH+"%04d_%04d.png" % (epoch, index))

            # Update the discriminator (label 1 = real, label 0 = generated)
            X = np.concatenate((image_batch, generated_images))
            y = np.array([1]*BATCH_SIZE + [0]*BATCH_SIZE)
            d_loss = discriminator.train_on_batch(X, y)

            # Update the generator (it is trained to make the discriminator output 1)
            noise = np.random.uniform(-1, 1, (BATCH_SIZE, 100))
            g_loss = dcgan.train_on_batch(noise, np.array([1]*BATCH_SIZE))
            print("epoch: %d, batch: %d, g_loss: %f, d_loss: %f" % (epoch, index, g_loss, d_loss))

        generator.save_weights('generator.h5')
        discriminator.save_weights('discriminator.h5')
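
The training code scales MNIST pixels to [-1, 1] to match the `tanh` output of the generator, and maps generated images back to [0, 255] before saving them. A minimal NumPy sketch of that round trip follows. The sample pixel values are made up for illustration. The sketch also rounds with `np.rint` before casting, which is an addition not present in the code above, where `astype(np.uint8)` truncates.

```python
import numpy as np

# Made-up sample pixel values covering the uint8 range.
pixels = np.array([0, 64, 128, 255], dtype=np.uint8)

# Forward mapping used on X_train: [0, 255] -> [-1, 1].
scaled = (pixels.astype(np.float32) - 127.5) / 127.5

# Inverse mapping used before saving: [-1, 1] -> [0, 255].
# Rounding first avoids off-by-one results from floating point error.
restored = np.rint(scaled * 127.5 + 127.5).astype(np.uint8)

print(float(scaled.min()), float(scaled.max()))
print(restored)
```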