GANs — Conditional GANs with CIFAR10 (Part 9)
A brief theoretical introduction to Conditional Generative Adversarial Nets (CGANs) and a practical implementation using Python and Keras/TensorFlow in a Jupyter Notebook.
In this article, you will find:
- Research paper,
- Definition, network design, and cost function, and
- Training CGANs with the CIFAR10 dataset using Python and Keras/TensorFlow in a Jupyter Notebook.
Research Paper
Mirza, M., & Osindero, S. (2014). Conditional Generative Adversarial Nets. ArXiv, abs/1411.1784.
Conditional Generative Adversarial Nets — CGANs
Generative adversarial nets can be extended to a conditional model if both the Generator G and Discriminator D are conditioned on some extra information y.
- y could be any kind of auxiliary information, such as class labels or data from other modalities.
We can perform the conditioning by feeding y into both the Generator G and the Discriminator D as an additional input layer.
- Generator: The prior input noise p(z) and y are combined in a joint hidden representation, and the adversarial training framework allows for considerable flexibility in how this hidden representation is composed.
- Discriminator: x and y are presented as inputs to a discriminative function.
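As a minimal sketch of the simplest form of conditioning, the snippet below concatenates a batch of noise vectors with one-hot class labels, which is the kind of joint input the generator in this article receives. The sizes (latent dimension 100, 10 classes) match the model used later in the article; the helper name `make_generator_input` is hypothetical.

```python
import numpy as np

def make_generator_input(batch_size, latent_dim=100, num_classes=10, seed=0):
    """Return (z, y_onehot, joint) for a random batch; illustrative helper only."""
    rng = np.random.default_rng(seed)
    # Prior noise z ~ N(0, 1)
    z = rng.normal(loc=0.0, scale=1.0, size=(batch_size, latent_dim))
    # Conditioning information y: random class ids turned into one-hot rows
    class_ids = rng.integers(0, num_classes, size=batch_size)
    y_onehot = np.eye(num_classes)[class_ids]
    # Joint representation: noise and label side by side
    joint = np.concatenate([z, y_onehot], axis=1)
    return z, y_onehot, joint

z, y, joint = make_generator_input(batch_size=4)
print(joint.shape)  # (4, 110)
```

Concatenation is only one option; the paper leaves open how the joint representation is composed.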
Network design
x is the real data, y is the class label, and z is a noise vector sampled from the latent space.
Cost function
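For reference, the objective given by Mirza and Osindero (2014) is the standard GAN minimax game with both networks conditioned on y:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)} \left[ \log D(x \mid y) \right]
  + \mathbb{E}_{z \sim p_z(z)} \left[ \log \left( 1 - D(G(z \mid y)) \right) \right]
```

The discriminator maximizes V by scoring real pairs (x, y) high and generated pairs low, while the generator minimizes it. In the training code of this article both terms are implemented with binary cross-entropy.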
Training CGANs
1. Data: CIFAR10 dataset
from tensorflow.keras.datasets import cifar10
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
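A preprocessing sketch, under assumptions the article does not state: pixel values are scaled from [0, 255] to [-1, 1] (a common choice when a generator ends in `tanh`), and integer labels are one-hot encoded. A small random array stands in for `X_train` and `y_train` so the snippet runs without downloading the dataset; the `preprocess` name is hypothetical.

```python
import numpy as np

def preprocess(images_uint8, class_ids, num_classes=10):
    """Scale uint8 images to [-1, 1] and one-hot encode integer labels."""
    images = images_uint8.astype(np.float32) / 127.5 - 1.0
    onehot = np.eye(num_classes, dtype=np.float32)[class_ids.reshape(-1)]
    return images, onehot

# Stand-in data with CIFAR10's layout: images (N, 32, 32, 3), labels (N, 1)
rng = np.random.default_rng(0)
X_demo = rng.integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
y_demo = rng.integers(0, 10, size=(8, 1))

X_scaled, y_onehot = preprocess(X_demo, y_demo)
print(X_scaled.shape, y_onehot.shape)  # (8, 32, 32, 3) (8, 10)
```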
2. Model:
- Generator
# Generator network
merged_layer = Concatenate()([z, labels])
# FC: 2x2x512
generator = Dense(2*2*512, activation='relu')(merged_layer)
generator = BatchNormalization(momentum=0.9)(generator)
generator = LeakyReLU(alpha=0.1)(generator)
generator = Reshape((2, 2, 512))(generator)
# Conv 1: 4x4x256
generator = Conv2DTranspose(256, kernel_size=5, strides=2, padding='same')(generator)
generator = BatchNormalization(momentum=0.9)(generator)
generator = LeakyReLU(alpha=0.1)(generator)
# Conv 2, 3 and 4
...
generator = Model(inputs=[z, labels], outputs=generator, name='generator')
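The size comments in the generator (2x2, then 4x4) follow from how `Conv2DTranspose` behaves with `padding='same'`: each layer multiplies the spatial size by its stride. The sketch below checks that arithmetic, assuming all four transposed convolutions use `strides=2` like Conv 1. The article elides Conv 2 to 4, so that part is an assumption, and the function name is hypothetical.

```python
def transposed_same_sizes(start, strides):
    """Spatial size after each Conv2DTranspose with padding='same'."""
    sizes = [start]
    for stride in strides:
        # With 'same' padding the output size is input size times stride
        sizes.append(sizes[-1] * stride)
    return sizes

# Assumption: Conv 1 to 4 all use strides=2
sizes = transposed_same_sizes(2, [2, 2, 2, 2])
print(sizes)  # [2, 4, 8, 16, 32]
```

Under that assumption the last layer lands on 32x32, the resolution of CIFAR10 images.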
- Discriminator
# Conv 1: 16x16x64
discriminator = Conv2D(64, kernel_size=5, strides=2, padding='same')(img_input)
discriminator = BatchNormalization(momentum=0.9)(discriminator)
discriminator = LeakyReLU(alpha=0.1)(discriminator)
# Conv 2, 3 and 4
...
# FC
discriminator = Flatten()(discriminator)
# Concatenate
merged_layer = Concatenate()([discriminator, labels])
discriminator = Dense(512, activation='relu')(merged_layer)
# Output
discriminator = Dense(1, activation='sigmoid')(discriminator)
discriminator = Model(inputs=[img_input, labels], outputs=discriminator, name='discriminator')
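With `padding='same'`, a strided `Conv2D` produces an output of size ceil(input / stride), which is where the `16x16` in the Conv 1 comment comes from. The sketch below applies that rule, assuming that Conv 2 to 4 also use `strides=2` (the article elides these layers), and shows how the one-hot label widens the flattened feature vector before the dense layer. The `channels` value is an arbitrary placeholder, because the filter count of the last convolution is not given here; both function names are hypothetical.

```python
import math

def conv_same_sizes(start, strides):
    """Spatial size after each Conv2D with padding='same'."""
    sizes = [start]
    for stride in strides:
        sizes.append(math.ceil(sizes[-1] / stride))
    return sizes

def merged_width(spatial, channels, num_classes=10):
    """Width of Concatenate()([Flatten(features), labels])."""
    return spatial * spatial * channels + num_classes

# Assumption: Conv 1 to 4 all use strides=2
sizes = conv_same_sizes(32, [2, 2, 2, 2])
print(sizes)  # [32, 16, 8, 4, 2]

# channels=8 is an arbitrary illustrative value, not the article's
print(merged_width(sizes[-1], channels=8))  # 42
```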
3. Compile
discriminator.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
# Freeze the discriminator inside the combined model so that only the generator is updated
discriminator.trainable = False
label = Input(shape=(10,), name='label')
z = Input(shape=(100,), name='z')
fake_img = generator([z, label])
validity = discriminator([fake_img, label])
d_g = Model([z, label], validity, name='adversarial')
d_g.compile(optimizer=Adam(learning_rate=0.0004, beta_1=0.5), loss='binary_crossentropy',
            metrics=['binary_accuracy'])
4. Fit
# Train Discriminator weights
discriminator.trainable = True
# Real samples
X_batch = X_train[i*batch_size:(i+1)*batch_size]
real_labels = to_categorical(y_train[i*batch_size:(i+1)*batch_size].reshape(-1, 1), num_classes=10)
d_loss_real = discriminator.train_on_batch(x=[X_batch, real_labels], y=real * (1 - smooth))
# Fake Samples
z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
random_labels = to_categorical(np.random.randint(0, 10, batch_size).reshape(-1, 1), num_classes=10)
X_fake = generator.predict_on_batch([z, random_labels])
d_loss_fake = discriminator.train_on_batch(x=[X_fake, random_labels], y=fake)
# Discriminator loss
d_loss_batch = 0.5 * (d_loss_real[0] + d_loss_fake[0])
# Train Generator weights
discriminator.trainable = False
z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
random_labels = to_categorical(np.random.randint(0, 10, batch_size).reshape(-1, 1), num_classes=10)
d_g_loss_batch = d_g.train_on_batch(x=[z, random_labels], y=real)
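The fit step uses `real`, `fake`, and `smooth` without defining them here. A plausible setup, stated as an assumption rather than the notebook's exact code, is a column of ones, a column of zeros, and a small smoothing factor, so that the targets for real images become `1 - smooth` (one-sided label smoothing). The values of `batch_size` and `smooth` are illustrative, and `batch_slices` is a hypothetical helper that mirrors the `X_train[i*batch_size:(i+1)*batch_size]` indexing.

```python
import numpy as np

batch_size = 64   # illustrative value
smooth = 0.1      # illustrative value

# Discriminator targets: 1 for real images, 0 for generated ones
real = np.ones((batch_size, 1), dtype=np.float32)
fake = np.zeros((batch_size, 1), dtype=np.float32)

# One-sided label smoothing: real targets are pulled slightly below 1
real_smoothed = real * (1 - smooth)

def batch_slices(num_samples, batch_size):
    """Yield (start, stop) index pairs for every full batch."""
    for i in range(num_samples // batch_size):
        yield i * batch_size, (i + 1) * batch_size

# The CIFAR10 training set has 50,000 images
num_batches = len(list(batch_slices(50000, batch_size)))
print(num_batches)  # 781
```

Note that the generator step passes the unsmoothed `real` targets: the generator is rewarded when the discriminator labels its samples as real.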
5. Evaluate
# plotting the metrics
plt.plot(d_loss)
plt.plot(d_g_loss)
plt.show()
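`d_loss` and `d_g_loss` are plotted here but are not built in the excerpts above. One way to obtain them, sketched under the assumption that each holds one averaged value per epoch, is to accumulate the batch losses and append their mean when the epoch ends. The `EpochLossTracker` class and the loss values are hypothetical.

```python
class EpochLossTracker:
    """Collect batch losses and keep one mean value per epoch."""

    def __init__(self):
        self.history = []   # one entry per finished epoch
        self._sum = 0.0
        self._count = 0

    def add_batch(self, loss):
        self._sum += float(loss)
        self._count += 1

    def end_epoch(self):
        mean = self._sum / self._count if self._count else float("nan")
        self.history.append(mean)
        self._sum, self._count = 0.0, 0
        return mean

# Made-up batch losses for two epochs
tracker = EpochLossTracker()
for batch_loss in [1.0, 0.5]:
    tracker.add_batch(batch_loss)
tracker.end_epoch()
for batch_loss in [0.5, 0.25]:
    tracker.add_batch(batch_loss)
tracker.end_epoch()

d_loss = tracker.history
print(d_loss)  # [0.75, 0.375]
```

A list built this way can be passed straight to `plt.plot`; adding `label=` arguments and `plt.legend()` makes the two curves easier to tell apart.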
CGANs — CIFAR10 results
Train summary
Github repository
See the complete CGAN training on the CIFAR10 dataset, using Python and Keras/TensorFlow, in the Jupyter Notebook.
For all the articles in our GANs series, here is the link.