GANs — Context-Conditional GANs with MNIST (Part 5)

fernanda rodríguez
4 min read · Jun 16, 2020

A brief theoretical introduction to Context-Conditional Generative Adversarial Nets (CCGANs) and a practical implementation using Python and Keras/TensorFlow in a Jupyter Notebook.

Context-Conditional Generative Adversarial Nets or CCGANs by fernanda rodríguez.

In this article, you will find:

  • Research paper,
  • Definition, network design, and cost function, and
  • Training CCGANs with MNIST dataset using Python and Keras/TensorFlow in Jupyter Notebook.

Research Paper

Denton, E.L., Gross, S., & Fergus, R. (2016). Semi-Supervised Learning with Context-Conditional Generative Adversarial Networks. ArXiv, abs/1611.06430.

Context-Conditional Generative Adversarial Nets — CCGANs

Context-Conditional Generative Adversarial Networks (CC-GANs) are conditional GANs where

  • The Generator 𝐺 is trained to fill in a missing image patch and
  • The Generator 𝐺 and Discriminator 𝐷 are conditioned on the surrounding pixels.

Unlike a standard GAN, where the discriminator judges whole images, CC-GANs address a different task:

  • Determining whether a part of an image is real or fake, given the surrounding context.

The Generator 𝐺 receives as input an image with a randomly masked out patch. The Generator 𝐺 outputs an entire image. We fill in the missing patch from the generated output and then pass the completed image into 𝐷.
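The random masking step (the `mask_randomly` helper used later in the training code) is not shown in the article. A minimal NumPy sketch might look like this, assuming square patches of a fixed size (the mask size and implementation details are assumptions):

```python
import numpy as np

def mask_randomly(imgs, mask_size=8):
    """Zero out one random square patch per image; imgs has shape (N, H, W, C)."""
    masked = imgs.copy()
    n, h, w, _ = imgs.shape
    ys = np.random.randint(0, h - mask_size, n)
    xs = np.random.randint(0, w - mask_size, n)
    for i in range(n):
        masked[i, ys[i]:ys[i] + mask_size, xs[i]:xs[i] + mask_size, :] = 0
    return masked
```

The generator then receives `masked_imgs = mask_randomly(imgs)` and must hallucinate the zeroed-out region from the surrounding pixels.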


Network design

Context-Conditional Generative Adversarial Nets or CCGANs Architecture by fernanda rodríguez.

x is the real data and z is the latent space.

Cost function

Cost function CCGANs by fernanda rodríguez.
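Following the notation of Denton et al. (2016), the CC-GAN objective can be sketched as follows, where m is a binary mask that zeroes out the missing patch:

```latex
\min_G \max_D \;
\mathbb{E}_{x \sim \mathcal{X}} \left[ \log D(x) \right]
+ \mathbb{E}_{x \sim \mathcal{X},\, m \sim \mathcal{M}} \left[ \log \left( 1 - D(x_I) \right) \right]
```

Here $x_G = G(m \odot x, z)$ is the generator output and $x_I = (1 - m) \odot x_G + m \odot x$ is the inpainted image: the generated patch pasted into the real surrounding context before being passed to the discriminator.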

Training CCGANs

  1. Data: MNIST dataset
(X_train, y_train), (X_test, y_test) = mnist.load_data()
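The article does not show the preprocessing step. Since the generator ends in a tanh activation, the images are presumably scaled to [−1, 1] and given a channel axis; a minimal sketch (the helper name is an assumption):

```python
import numpy as np

def preprocess(images):
    """Scale uint8 images to [-1, 1] (matching the generator's tanh output)
    and add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1)."""
    x = (images.astype("float32") - 127.5) / 127.5
    return np.expand_dims(x, axis=-1)
```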

2. Model:

  • Generator
# Generator network (encoder–decoder built with the functional API)
img_g = Input(shape=img_shape)

# Downsampling
d1 = Conv2D(gf, kernel_size=k, strides=s, padding='same')(img_g)
d1 = LeakyReLU(alpha=0.2)(d1)
# d2, d3, d4 repeat the same Conv2D + LeakyReLU pattern with growing filter counts
d2 ...
d3 ...
d4 ...
d4 = BatchNormalization(momentum=0.8)(d4)

# Upsampling
u1 = UpSampling2D(size=2)(d4)
u1 = Conv2D(gf*4, kernel_size=k, strides=1, padding='same', activation='relu')(u1)
u1 = BatchNormalization(momentum=0.8)(u1)
# u2, u3, u4 repeat the same UpSampling2D + Conv2D pattern with shrinking filter counts
u2 ...
u3 ...
u4 ...
u4 = Conv2D(1, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

generator = Model(img_g, u4)
  • Discriminator
# Discriminator network: a shared convolutional body with two heads
body = Sequential()
body.add(Conv2D(64, kernel_size=4, strides=2, padding='same', input_shape=img_shape))
body.add(LeakyReLU(alpha=0.8))
# Conv2D 128 and Conv2D 256 blocks follow the same pattern
...

img_d = Input(shape=img_shape)
features = body(img_d)

# Head 1: spatial validity map (real/fake)
validity = Conv2D(1, kernel_size=k, strides=1, padding='same')(features)

# Head 2: auxiliary class label (num_classes digits + 1 extra "fake" class)
label = Flatten()(features)
label = Dense(num_classes+1, activation="softmax")(label)

discriminator = Model(img_d, [validity, label])

3. Compile

optimizer = Adam(lr=0.0002, beta_1=0.5)
discriminator.compile(optimizer=optimizer,
                      loss=['mse', 'categorical_crossentropy'],
                      loss_weights=[0.5, 0.5],
                      metrics=['accuracy'])

# The generator takes a masked image as input and generates a complete image
masked_img = Input(shape=img_shape)
gen_img = generator(masked_img)

# In the combined model, only the generator's weights are trained
discriminator.trainable = False
validity, _ = discriminator(gen_img)

d_g = Model(masked_img, validity)

d_g.compile(optimizer=optimizer, loss='mse', metrics=['accuracy'])

4. Fit

# Train Discriminator weights
discriminator.trainable = True

# Real samples
img_real = X_train[i*batch_size:(i+1)*batch_size]
real_labels = y_train[i*batch_size:(i+1)*batch_size]

d_loss_real = discriminator.train_on_batch(x=img_real, y=[real, real_labels])

# Fake Samples
masked_imgs = mask_randomly(img_real)
gen_imgs = generator.predict(masked_imgs)

d_loss_fake = discriminator.train_on_batch(x=gen_imgs, y=[fake, fake_labels])

# Discriminator loss
d_loss_batch = 0.5 * (d_loss_real[0] + d_loss_fake[0])

# Train Generator weights (the combined model takes the masked images as input)
discriminator.trainable = False
d_g_loss_batch = d_g.train_on_batch(x=masked_imgs, y=real)
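The targets `real`, `fake`, and `fake_labels` used above are not defined in the article. Assuming the validity head emits a spatial map and the label head has num_classes + 1 outputs (the extra column reserved for generated samples), they could be built like this (the helper name and the 7×7 map size are assumptions):

```python
import numpy as np

batch_size, num_classes = 32, 10

def to_categorical_labels(y, num_classes=10):
    """One-hot targets for the (num_classes + 1)-way label head.
    Real digits keep their class; column num_classes is reserved for 'fake'."""
    onehot = np.zeros((len(y), num_classes + 1), dtype="float32")
    onehot[np.arange(len(y)), y] = 1.0
    return onehot

# Generated samples all map to the extra 'fake' class
fake_labels = to_categorical_labels(np.full(batch_size, num_classes), num_classes)

# Validity targets: ones for real patches, zeros for generated ones,
# shaped like the Conv2D validity map (hypothetically (batch, 7, 7, 1))
real = np.ones((batch_size, 7, 7, 1), dtype="float32")
fake = np.zeros((batch_size, 7, 7, 1), dtype="float32")
```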

5. Evaluate

# plotting the metrics
# d_loss and d_g_loss are the per-epoch loss histories collected during training
import matplotlib.pyplot as plt

plt.plot(d_loss, label='discriminator')
plt.plot(d_g_loss, label='generator')
plt.legend()
plt.show()

CCGANs — MNIST results

epoch = 1/100, d_loss=0.419, g_loss=0.205 in CCGAN_MNIST
epoch = 100/100, d_loss=0.063, g_loss=0.200 in CCGAN_MNIST

Train summary

Train summary CCGANs by fernanda rodríguez.

Github repository

See the complete training of the CCGAN with the MNIST dataset, using Python and Keras/TensorFlow, in the Jupyter Notebook.



Written by fernanda rodríguez

hi, i’m maría fernanda rodríguez r. multimedia engineer. data scientist. front-end dev. phd candidate: augmented reality + machine learning.
