stMind

about Tech, Computer vision and Machine learning

KerasでGAN

towardsdatascience.com

Mediumの記事を参考に、一番基本のGANについて試してみた。データセットはおなじみのfashion mnist。

GANのアーキテクチャ

ノイズ画像(100次元のランダムなベクトル)からfashion画像を生成するgeneratorは、3層の全結合層から成るネットワーク。各層の出力次元数は28, 29, 210としている。

def get_generator(optimizer, output_dim=784):
    generator = Sequential()
    generator.add(
        Dense(
            256,
            input_dim=random_dim,
            kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(512))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(1024))
    generator.add(LeakyReLU(0.2))

    generator.add(Dense(output_dim, activation='tanh'))
    generator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return generator

一方のdiscriminatorも3層の全結合層から成るネットワーク。出力ノードはrealかfakeかのバイナリ値。同じように、各層の出力次元数は210, 29, 28としている。

def get_discriminator(optimizer, input_dim=784):
    discriminator = Sequential()
    discriminator.add(
        Dense(
            1024,
            input_dim=input_dim,
            kernel_initializer=initializers.RandomNormal(stddev=0.02)))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    discriminator.add(Dense(512))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    discriminator.add(Dense(256))
    discriminator.add(LeakyReLU(0.2))
    discriminator.add(Dropout(0.3))

    discriminator.add(Dense(1, activation='sigmoid'))
    discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
    return discriminator

GANの学習

最初にdiscriminatorだけを学習してパラメータを更新した後、ノイズ画像(100次元ベクトル)からgeneratorとdiscriminatorを通してrealかfakeかを判定するend-to-endのネットワーク(get_gan_network)に対して、discriminatorは重みを固定してgeneratorだけを学習する、という二段階の処理です。

discriminatorの学習では、x_trainからランダムにピックアップしたreal画像とgeneratorで生成したfake画像をそれぞれbatch_size個用意して、それらを連結したデータ(Xとy_dis)を作成して与えています。また、ノイズ画像から生成した画像をrealと判定するようなgeneratorにするために、ラベルを1としてgeneratorを学習しています。

def train(epochs=1, batch_size=128):
    # Get the training and testing data
    (x_train, y_train, x_test, y_test), x_height_width = load_mnist_fashion()
    # Split the training data into batches of size 128
    batch_count = x_train.shape[0] // batch_size

    # Build our GAN netowrk
    adam = get_optimizer()
    generator = get_generator(adam, output_dim=x_train.shape[1])
    discriminator = get_discriminator(adam, x_train.shape[1])
    gan = get_gan_network(discriminator, random_dim, generator, adam)

    for e in range(1, epochs + 1):
        print('-' * 15, 'Epoch %d' % e, '-' * 15)
        for _ in tqdm(range(batch_count)):
            # Get a random set of input noise and images
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            image_batch = x_train[np.random.randint(
                0, x_train.shape[0], size=batch_size)]

            # Generate fake images
            generated_images = generator.predict(noise)
            X = np.concatenate([image_batch, generated_images])

            # Labels for generated and real data
            y_dis = np.zeros(2 * batch_size)
            # One-sided label smoothing
            y_dis[:batch_size] = 1.0

            # Train discriminator
            discriminator.trainable = True
            discriminator.train_on_batch(X, y_dis)

            # Train generator
            noise = np.random.normal(0, 1, size=[batch_size, random_dim])
            y_gen = np.ones(batch_size)
            discriminator.trainable = False
            gan.train_on_batch(noise, y_gen)

        if e == 1 or e % 20 == 0:
            plot_generated_images(e, generator, train_shape=x_height_width)

生成された画像

batch_size=128で、300epoch学習を回した結果。

  • Epoch1では、まだ単なる矩形画像にしかなっていない

f:id:satojkovic:20180521235550p:plain

  • Epoch100では、各アイテムがかなり分かる程度になっているが、一部単なる矩形画像が残っている

f:id:satojkovic:20180521235705p:plain

  • Epoch300では、各アイテムがはっきりと分かる程度に生成された

f:id:satojkovic:20180521235922p:plain

レポジトリ

github.com