Mediumの記事を参考に、一番基本のGANについて試してみた。データセットはおなじみのfashion mnist。
GANのアーキテクチャ
ノイズ画像(100次元のランダムなベクトル)からfashion画像を生成するgeneratorは、3層の全結合層から成るネットワーク。各層の出力次元数は28, 29, 210としている。
def get_generator(optimizer, output_dim=784): generator = Sequential() generator.add( Dense( 256, input_dim=random_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02))) generator.add(LeakyReLU(0.2)) generator.add(Dense(512)) generator.add(LeakyReLU(0.2)) generator.add(Dense(1024)) generator.add(LeakyReLU(0.2)) generator.add(Dense(output_dim, activation='tanh')) generator.compile(loss='binary_crossentropy', optimizer=optimizer) return generator
一方のdiscriminatorも3層の全結合層から成るネットワーク。出力ノードはrealかfakeかのバイナリ値。同じように、各層の出力次元数は210, 29, 28としている。
def get_discriminator(optimizer, input_dim=784): discriminator = Sequential() discriminator.add( Dense( 1024, input_dim=input_dim, kernel_initializer=initializers.RandomNormal(stddev=0.02))) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(512)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(256)) discriminator.add(LeakyReLU(0.2)) discriminator.add(Dropout(0.3)) discriminator.add(Dense(1, activation='sigmoid')) discriminator.compile(loss='binary_crossentropy', optimizer=optimizer) return discriminator
GANの学習
最初にdiscriminatorだけを学習してパラメータを更新した後、ノイズ画像(100次元ベクトル)からgeneratorとdiscriminatorを通してrealかfakeかを判定するend-to-endのネットワーク(get_gan_network)に対して、discriminatorは重みを固定してgeneratorだけを学習する、という二段階の処理です。
discriminatorの学習では、x_trainからランダムにピックアップしたreal画像とgeneratorで生成したfake画像をそれぞれbatch_size個用意して、それらを連結したデータ(Xとy_dis)を作成して与えています。また、ノイズ画像から生成した画像をrealと判定するようなgeneratorにするために、ラベルを1としてgeneratorを学習しています。
def train(epochs=1, batch_size=128): # Get the training and testing data (x_train, y_train, x_test, y_test), x_height_width = load_mnist_fashion() # Split the training data into batches of size 128 batch_count = x_train.shape[0] // batch_size # Build our GAN netowrk adam = get_optimizer() generator = get_generator(adam, output_dim=x_train.shape[1]) discriminator = get_discriminator(adam, x_train.shape[1]) gan = get_gan_network(discriminator, random_dim, generator, adam) for e in range(1, epochs + 1): print('-' * 15, 'Epoch %d' % e, '-' * 15) for _ in tqdm(range(batch_count)): # Get a random set of input noise and images noise = np.random.normal(0, 1, size=[batch_size, random_dim]) image_batch = x_train[np.random.randint( 0, x_train.shape[0], size=batch_size)] # Generate fake images generated_images = generator.predict(noise) X = np.concatenate([image_batch, generated_images]) # Labels for generated and real data y_dis = np.zeros(2 * batch_size) # One-sided label smoothing y_dis[:batch_size] = 1.0 # Train discriminator discriminator.trainable = True discriminator.train_on_batch(X, y_dis) # Train generator noise = np.random.normal(0, 1, size=[batch_size, random_dim]) y_gen = np.ones(batch_size) discriminator.trainable = False gan.train_on_batch(noise, y_gen) if e == 1 or e % 20 == 0: plot_generated_images(e, generator, train_shape=x_height_width)
生成された画像
batch_size=128で、300epoch学習を回した結果。
- Epoch1では、まだ単なる矩形画像にしかなっていない
- Epoch100では、各アイテムがかなり分かる程度になっているが、一部単なる矩形画像が残っている
- Epoch300では、各アイテムがはっきりと分かる程度に生成された