GANs — Conditional GANs with MNIST (Part 4)

fernanda rodríguez
4 min read · Jun 16, 2020


A brief theoretical introduction to Conditional Generative Adversarial Nets (CGANs) and a practical implementation using Python and Keras/TensorFlow in a Jupyter Notebook.

Conditional Generative Adversarial Nets or CGANs by fernanda rodríguez.

In this article, you will find:

  • Research paper,
  • Definition, network design, and cost function, and
  • Training CGANs with the MNIST dataset using Python and Keras/TensorFlow in a Jupyter Notebook.

Research Paper

Mirza, M., & Osindero, S. (2014). Conditional Generative Adversarial Nets. ArXiv, abs/1411.1784.

Conditional Generative Adversarial Nets — CGANs

Generative adversarial nets can be extended to a conditional model if both the generator and discriminator are conditioned on some extra information y.

  • y could be any kind of auxiliary information, such as class labels or data from other modalities.

We can perform the conditioning by feeding y into both the discriminator and the generator as an additional input layer (see the sketch after the list):

  • Generator: the prior input noise p_z(z) and y are combined in a joint hidden representation, and the adversarial training framework allows considerable flexibility in how this hidden representation is composed.
  • Discriminator: x and y are presented as inputs to a discriminative function.

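Conceptually, the simplest form of conditioning is to concatenate a one-hot encoding of y to each network's usual input. The minimal sketch below assumes this concatenation approach for the generator; the implementation later in this article instead combines a label embedding with the input by element-wise multiplication.

# Minimal sketch: conditioning a generator by concatenating a one-hot label to the noise
from tensorflow.keras.layers import Input, Dense, concatenate
from tensorflow.keras.models import Model

z = Input(shape=(100,))                      # noise vector z
y = Input(shape=(10,))                       # one-hot class label y
h = Dense(128, activation='relu')(concatenate([z, y]))
x_fake = Dense(784, activation='tanh')(h)    # generated (flattened) image conditioned on y
G = Model([z, y], x_fake)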
Read more about GANs in the previous parts of this series.

Network design

Conditional Generative Adversarial Nets or CGANs Architecture by fernanda rodríguez.

x is the real data, y the class labels, and z the latent-space noise vector.

Cost function

Cost function CGANs by fernanda rodríguez.
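In text form, the objective from Mirza & Osindero (2014) is the standard two-player minimax game with both the discriminator and the generator conditioned on y:

min_G max_D V(D, G) = E_{x ~ p_data(x)}[log D(x | y)] + E_{z ~ p_z(z)}[log(1 − D(G(z | y)))]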

Training CGANs

1. Data: MNIST dataset
(X_train, y_train), (X_test, y_test) = mnist.load_data()
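The snippets below assume the images are flattened to vectors of length img_dim and scaled to [-1, 1] so they match the generator's tanh output; a minimal preprocessing sketch (the img_dim and latent_dim values are assumptions, not shown above):

import numpy as np

img_dim = 28 * 28      # flattened 28x28 images, used by the dense networks below
latent_dim = 100       # length of the noise vector z (assumed value)

# Scale pixels from [0, 255] to [-1, 1] and flatten each image
X_train = X_train.reshape(-1, img_dim).astype('float32') / 127.5 - 1.0
X_test = X_test.reshape(-1, img_dim).astype('float32') / 127.5 - 1.0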

2. Model:

  • Generator
# Generator network
generator = Sequential()

# Input layer and hidden layer 1
generator.add(Dense(128, input_shape=(latent_dim,), kernel_initializer=init))
generator.add(LeakyReLU(alpha=0.2))
generator.add(BatchNormalization(momentum=0.8))

# Hidden layer 2 and 3
...

# Output layer
generator.add(Dense(img_dim, activation='tanh'))
  • Conditional G model
# Create label embeddings
label = Input(shape=(1,), dtype='int32')
label_embedding = Embedding(10, latent_dim)(label)
label_embedding = Flatten()(label_embedding)

# latent space
z = Input(shape=(latent_dim,))

# Output image
img = generator(multiply([z, label_embedding]))

# Generator with condition input
generator = Model([z, label], img)
  • Discriminator
# Discriminator network
discriminator = Sequential()

# Input layer and hidden layer 1
discriminator.add(Dense(128, input_shape=(img_dim,), kernel_initializer=init))
discriminator.add(LeakyReLU(alpha=0.2))

# Hidden layer 2 and 3
...
# Output layer
discriminator.add(Dense(1, activation='sigmoid'))
  • Conditional D model
# Create label embeddings
label_d = Input(shape=(1,), dtype='int32')
label_embedding_d = Embedding(10, img_dim)(label_d)
label_embedding_d = Flatten()(label_embedding_d)

# image dimension 28x28
img_d = Input(shape=(img_dim,))

# Output image
validity = discriminator(multiply([img_d, label_embedding_d]))

# Discriminator with condition input
discriminator = Model([img_d, label_d], validity)
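The model snippets above also rely on a few Keras imports and a weight initializer init; a minimal setup sketch (the exact imports and initializer are assumptions, since they are not shown in the snippets):

# Assumed imports and initializer for the generator/discriminator snippets above
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization
from tensorflow.keras.layers import Input, Embedding, Flatten, multiply
from tensorflow.keras.initializers import RandomNormal
from tensorflow.keras.optimizers import Adam

init = RandomNormal(mean=0.0, stddev=0.02)   # DCGAN-style weight initialization (assumed)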

3. Compile

optimizer = Adam(lr=0.0002, beta_1=0.5)

# Compile the discriminator so it can be trained on its own
discriminator.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['binary_accuracy'])

# Freeze the discriminator's weights inside the combined model
discriminator.trainable = False

# Combined model: the generator followed by the frozen discriminator
validity = discriminator([generator([z, label]), label])
d_g = Model([z, label], validity)
d_g.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['binary_accuracy'])

4. Fit

# Train Discriminator weights
discriminator.trainable = True

# Real samples
X_batch = X_train[i*batch_size:(i+1)*batch_size]
real_labels = y_train[i*batch_size:(i+1)*batch_size].reshape(-1, 1)
d_loss_real = discriminator.train_on_batch(x=[X_batch, real_labels], y=real * (1 - smooth))

# Fake Samples
z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
random_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
X_fake = generator.predict_on_batch([z, random_labels])
d_loss_fake = discriminator.train_on_batch(x=[X_fake, random_labels], y=fake)

# Discriminator loss
d_loss_batch = 0.5 * (d_loss_real[0] + d_loss_fake[0])

# Train Generator weights
discriminator.trainable = False
z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
random_labels = np.random.randint(0, 10, batch_size).reshape(-1, 1)
d_g_loss_batch = d_g.train_on_batch(x=[z, random_labels], y=real)
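The training loop above assumes a few variables defined beforehand: the mini-batch index i, the batch_size, a label-smoothing factor smooth, and the target vectors real and fake. A minimal setup sketch (the exact values are assumptions):

# Assumed setup for the training loop above
batch_size = 64
smooth = 0.1                               # one-sided label smoothing for real samples
real = np.ones(shape=(batch_size, 1))      # targets for real images
fake = np.zeros(shape=(batch_size, 1))     # targets for generated images

# One epoch then iterates i over X_train.shape[0] // batch_size mini-batches,
# running the discriminator and generator updates shown above for each batch.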

5. Evaluate

# plotting the metrics (d_loss and d_g_loss hold the per-epoch losses collected during training)
import matplotlib.pyplot as plt

plt.plot(d_loss, label='Discriminator')
plt.plot(d_g_loss, label='Generator')
plt.legend()
plt.show()

CGANs — MNIST results

epoch = 1/100, d_loss=0.699, g_loss=0.904 in CGAN_MNIST
epoch = 100/100, d_loss=0.674, g_loss=0.929 in CGAN_MNIST

Train summary

Train summary CGANs by fernanda rodríguez.

GitHub repository

See the complete CGAN training with the MNIST dataset, using Python and Keras/TensorFlow, in the Jupyter Notebook in the repository.


Written by fernanda rodríguez

hi, i’m maría fernanda rodríguez r. multimedia engineer. data scientist. front-end dev. phd candidate: augmented reality + machine learning.