GANs — Conditional GANs with CIFAR10 (Part 9)

fernanda rodríguez
4 min read · Jun 16, 2020

A brief theoretical introduction to Conditional Generative Adversarial Nets, or CGANs, and a practical implementation using Python and Keras/TensorFlow in a Jupyter Notebook.

Conditional Generative Adversarial Nets or CGANs by fernanda rodríguez.

In this article, you will find:

  • Research paper,
  • Definition, network design, and cost function, and
  • Training CGANs with CIFAR10 dataset using Python and Keras/TensorFlow in Jupyter Notebook.

Research Paper

Mirza, M., & Osindero, S. (2014). Conditional Generative Adversarial Nets. ArXiv, abs/1411.1784.

Conditional Generative Adversarial Nets — CGANs

Generative adversarial nets can be extended to a conditional model if both the Generator G and Discriminator D are conditioned on some extra information y.

  • y could be any kind of auxiliary information, such as class labels or data from other modalities.

We can perform the conditioning by feeding y into both the Generator G and the Discriminator D as an additional input layer.

  • Generator: the prior input noise p(z) and y are combined in a joint hidden representation, and the adversarial training framework allows considerable flexibility in how this hidden representation is composed.
  • Discriminator: x and y are presented as inputs to a discriminative function.


Network design

Conditional Generative Adversarial Nets or CGANs Architecture by fernanda rodríguez.

Here x is the real data, y is the class label, and z is the latent noise vector.
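
The code snippets that follow reuse these three tensors. To make them self-contained, they can be declared up front as Keras inputs; the shapes below are assumptions based on CIFAR10 (32×32 RGB images, 10 classes) and the 100-dimensional z used in the compile step:

from tensorflow.keras.layers import Input

latent_dim = 100  # size of the latent noise vector z

img_input = Input(shape=(32, 32, 3), name='img_input')  # x: a CIFAR10 image
labels = Input(shape=(10,), name='labels')              # y: a one-hot class label
z = Input(shape=(latent_dim,), name='z')                # z: latent noise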

Cost function

Cost function CGANs by fernanda rodríguez.
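
The image above shows the CGAN objective. From the Mirza & Osindero paper, it is the familiar two-player minimax game of the original GAN with both terms conditioned on y:

\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x \mid y)\right] + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z \mid y))\right)\right]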

Training CGANs

  1. Data: CIFAR10 dataset
from tensorflow.keras.datasets import cifar10
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
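
The notebook is not reproduced in full here, but a standard preprocessing step for DCGAN-style training (an assumption about this notebook) is to scale pixels to [-1, 1] so real images match a tanh generator output:

import numpy as np

# Scale pixel values from [0, 255] to [-1, 1]
X_train = X_train.astype(np.float32) / 127.5 - 1.0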

2. Model:

  • Generator
# Generator network
from tensorflow.keras.layers import (Concatenate, Dense, BatchNormalization,
                                     LeakyReLU, Reshape, Conv2DTranspose)
from tensorflow.keras.models import Model

# Combine the noise z and the one-hot label y into a joint hidden representation
merged_layer = Concatenate()([z, labels])

# FC: 2x2x512 (no activation in the Dense; LeakyReLU is applied after batch norm)
generator = Dense(2*2*512)(merged_layer)
generator = BatchNormalization(momentum=0.9)(generator)
generator = LeakyReLU(alpha=0.1)(generator)
generator = Reshape((2, 2, 512))(generator)
# Conv 1: 4x4x256
generator = Conv2DTranspose(256, kernel_size=5, strides=2, padding='same')(generator)
generator = BatchNormalization(momentum=0.9)(generator)
generator = LeakyReLU(alpha=0.1)(generator)
# Conv 2, 3 and 4
...
generator = Model(inputs=[z, labels], outputs=generator, name='generator')
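
The Conv 2, 3, and 4 blocks are elided above. A plausible completion (my sketch, not the author's exact code) keeps doubling the spatial resolution until it reaches CIFAR10's 32×32×3, ending in a tanh so outputs land in [-1, 1]; these layers would go where the ... is, before the Model(...) line:

# Assumed completion of the elided blocks (not from the original post)
# Conv 2: 8x8x128
generator = Conv2DTranspose(128, kernel_size=5, strides=2, padding='same')(generator)
generator = BatchNormalization(momentum=0.9)(generator)
generator = LeakyReLU(alpha=0.1)(generator)
# Conv 3: 16x16x64
generator = Conv2DTranspose(64, kernel_size=5, strides=2, padding='same')(generator)
generator = BatchNormalization(momentum=0.9)(generator)
generator = LeakyReLU(alpha=0.1)(generator)
# Conv 4: 32x32x3, tanh output
generator = Conv2DTranspose(3, kernel_size=5, strides=2, padding='same', activation='tanh')(generator)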
  • Discriminator
# Discriminator network
from tensorflow.keras.layers import Conv2D, Flatten

# Conv 1: 16x16x64
discriminator = Conv2D(64, kernel_size=5, strides=2, padding='same')(img_input)
discriminator = BatchNormalization(momentum=0.9)(discriminator)
discriminator = LeakyReLU(alpha=0.1)(discriminator)

# Conv 2, 3 and 4
...

# FC
discriminator = Flatten()(discriminator)

# Concatenate the flattened image features with the one-hot label y
merged_layer = Concatenate()([discriminator, labels])
discriminator = Dense(512, activation='relu')(merged_layer)

# Output: probability that the (image, label) pair is real
discriminator = Dense(1, activation='sigmoid')(discriminator)

discriminator = Model(inputs=[img_input, labels], outputs=discriminator, name='discriminator')
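
Note where the conditioning enters each network: the generator concatenates y with z at the very input, while the discriminator concatenates y only after the convolutional features are flattened. The paper leaves this choice open; conditioning at different depths is exactly the flexibility in composing the joint hidden representation mentioned earlier.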

3. Compile

from tensorflow.keras.optimizers import Adam

discriminator.compile(Adam(lr=0.0002, beta_1=0.5), loss='binary_crossentropy',
                      metrics=['binary_accuracy'])

# Freeze the discriminator inside the combined model: because d_g is compiled
# while trainable is False, training d_g updates only the generator's weights
discriminator.trainable = False

label = Input(shape=(10,), name='label')
z = Input(shape=(100,), name='z')

fake_img = generator([z, label])
validity = discriminator([fake_img, label])

d_g = Model([z, label], validity, name='adversarial')

d_g.compile(Adam(lr=0.0004, beta_1=0.5), loss='binary_crossentropy',
            metrics=['binary_accuracy'])

4. Fit

import numpy as np
from tensorflow.keras.utils import to_categorical

# Assumed setup (defined once before the training loop); batch_size and the
# batch index i come from the surrounding loop, latent_dim = 100
real = np.ones(shape=(batch_size, 1))
fake = np.zeros(shape=(batch_size, 1))
smooth = 0.1  # one-sided label smoothing: real targets become 0.9

# Train Discriminator weights
discriminator.trainable = True

# Real samples: the i-th mini-batch of images with their true labels
X_batch = X_train[i*batch_size:(i+1)*batch_size]
real_labels = to_categorical(y_train[i*batch_size:(i+1)*batch_size].reshape(-1, 1), num_classes=10)
d_loss_real = discriminator.train_on_batch(x=[X_batch, real_labels], y=real * (1 - smooth))

# Fake samples: generated from random noise and random class labels
z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
random_labels = to_categorical(np.random.randint(0, 10, batch_size).reshape(-1, 1), num_classes=10)
X_fake = generator.predict_on_batch([z, random_labels])
d_loss_fake = discriminator.train_on_batch(x=[X_fake, random_labels], y=fake)

# Discriminator loss: average of the real and fake batch losses
d_loss_batch = 0.5 * (d_loss_real[0] + d_loss_fake[0])

# Train Generator weights: freeze D and push fakes toward the "real" target
discriminator.trainable = False
z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
random_labels = to_categorical(np.random.randint(0, 10, batch_size).reshape(-1, 1), num_classes=10)
d_g_loss_batch = d_g.train_on_batch(x=[z, random_labels], y=real)
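
Once training finishes, the payoff of the conditioning is that the generator can be asked for a specific class. A minimal sketch of this (hypothetical usage, not from the original notebook; class 0 is 'airplane' in CIFAR10):

# Generate ten samples of one chosen class (here class 0, 'airplane')
z_sample = np.random.normal(loc=0, scale=1, size=(10, latent_dim))
class_labels = to_categorical(np.full(10, 0), num_classes=10)
generated = generator.predict([z_sample, class_labels])  # (10, 32, 32, 3), in [-1, 1]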

5. Evaluate

# Plotting the metrics collected during training (d_loss and d_g_loss per epoch)
import matplotlib.pyplot as plt

plt.plot(d_loss, label='discriminator loss')
plt.plot(d_g_loss, label='adversarial (generator) loss')
plt.legend()
plt.show()

CGANs — CIFAR10 results

epoch = 1/100, d_loss=0.291, g_loss=4.552 in CGAN_CIFAR10
epoch = 100/100, d_loss=0.188, g_loss=7.892 in CGAN_CIFAR10

Train summary

Train summary CGAN by fernanda rodríguez.

Github repository

See the complete training of the CGAN with the CIFAR10 dataset, using Python and Keras/TensorFlow, in the Jupyter Notebook.
