towardsdatascience.com
Mediumの記事を参考に、一番基本のGANについて試してみた。データセットはおなじみのfashion mnist。
ノイズ画像(100次元のランダムなベクトル)からfashion画像を生成するgeneratorは、3層の全結合層から成るネットワーク。各層の出力次元数は28, 29, 210としている。
def get_generator(optimizer, output_dim=784):
generator = Sequential()
generator.add(
Dense(
256,
input_dim=random_dim,
kernel_initializer=initializers.RandomNormal(stddev=0.02)))
generator.add(LeakyReLU(0.2))
generator.add(Dense(512))
generator.add(LeakyReLU(0.2))
generator.add(Dense(1024))
generator.add(LeakyReLU(0.2))
generator.add(Dense(output_dim, activation='tanh'))
generator.compile(loss='binary_crossentropy', optimizer=optimizer)
return generator
一方のdiscriminatorも3層の全結合層から成るネットワーク。出力ノードはrealかfakeかのバイナリ値。同じように、各層の出力次元数は210, 29, 28としている。
def get_discriminator(optimizer, input_dim=784):
discriminator = Sequential()
discriminator.add(
Dense(
1024,
input_dim=input_dim,
kernel_initializer=initializers.RandomNormal(stddev=0.02)))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(512))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(256))
discriminator.add(LeakyReLU(0.2))
discriminator.add(Dropout(0.3))
discriminator.add(Dense(1, activation='sigmoid'))
discriminator.compile(loss='binary_crossentropy', optimizer=optimizer)
return discriminator
GANの学習
最初にdiscriminatorだけを学習してパラメータを更新した後、ノイズ画像(100次元ベクトル)からgeneratorとdiscriminatorを通してrealかfakeかを判定するend-to-endのネットワーク(get_gan_network)に対して、discriminatorは重みを固定してgeneratorだけを学習する、という二段階の処理です。
discriminatorの学習では、x_trainからランダムにピックアップしたreal画像とgeneratorで生成したfake画像をそれぞれbatch_size個用意して、それらを連結したデータ(Xとy_dis)を作成して与えています。また、ノイズ画像から生成した画像をrealと判定するようなgeneratorにするために、ラベルを1としてgeneratorを学習しています。
def train(epochs=1, batch_size=128):
(x_train, y_train, x_test, y_test), x_height_width = load_mnist_fashion()
batch_count = x_train.shape[0] // batch_size
adam = get_optimizer()
generator = get_generator(adam, output_dim=x_train.shape[1])
discriminator = get_discriminator(adam, x_train.shape[1])
gan = get_gan_network(discriminator, random_dim, generator, adam)
for e in range(1, epochs + 1):
print('-' * 15, 'Epoch %d' % e, '-' * 15)
for _ in tqdm(range(batch_count)):
noise = np.random.normal(0, 1, size=[batch_size, random_dim])
image_batch = x_train[np.random.randint(
0, x_train.shape[0], size=batch_size)]
generated_images = generator.predict(noise)
X = np.concatenate([image_batch, generated_images])
y_dis = np.zeros(2 * batch_size)
y_dis[:batch_size] = 1.0
discriminator.trainable = True
discriminator.train_on_batch(X, y_dis)
noise = np.random.normal(0, 1, size=[batch_size, random_dim])
y_gen = np.ones(batch_size)
discriminator.trainable = False
gan.train_on_batch(noise, y_gen)
if e == 1 or e % 20 == 0:
plot_generated_images(e, generator, train_shape=x_height_width)
生成された画像
batch_size=128で、300epoch学習を回した結果。
- Epoch1では、まだ単なる矩形画像にしかなっていない
- Epoch100では、各アイテムがかなり分かる程度になっているが、一部単なる矩形画像が残っている
- Epoch300では、各アイテムがはっきりと分かる程度に生成された
レポジトリ
github.com