GANs — Context-Conditional GANs with MNIST (Part 5)

fernanda rodríguez
4 min read · Jun 16, 2020

A brief theoretical introduction to Context-Conditional Generative Adversarial Nets (CCGANs) and a practical implementation using Python and Keras/TensorFlow in a Jupyter Notebook.

Context-Conditional Generative Adversarial Nets or CCGANs by fernanda rodríguez.

In this article, you will find:

  • Research paper,
  • Definition, network design, and cost function, and
  • Training CCGANs with MNIST dataset using Python and Keras/TensorFlow in Jupyter Notebook.

Research Paper

Denton, E.L., Gross, S., & Fergus, R. (2016). Semi-Supervised Learning with Context-Conditional Generative Adversarial Networks. ArXiv, abs/1611.06430.

Context-Conditional Generative Adversarial Nets — CCGANs

Context-Conditional Generative Adversarial Networks (CC-GANs) are conditional GANs where

  • The Generator 𝐺 is trained to fill in a missing image patch and
  • The Generator 𝐺 and Discriminator 𝐷 are conditioned on the surrounding pixels.

Compared with a standard GAN, the Discriminator 𝐷 in a CC-GAN addresses a different task:

  • Determining whether a part of an image is real or fake, given the surrounding context.

The Generator 𝐺 receives as input an image with a randomly masked out patch. The Generator 𝐺 outputs an entire image. We fill in the missing patch from the generated output and then pass the completed image into 𝐷.
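The masking and fill-in steps described above can be sketched with NumPy. This is a minimal illustration, not the notebook's code: the helper names `mask_random_patch` and `fill_in`, the 10-pixel patch size, and the fill value of 0 are all assumptions, and a zero array stands in for the generator output.

```python
import numpy as np

def mask_random_patch(imgs, mask_size=10, fill_value=0.0, rng=None):
    """Blank out one random square patch per image (hypothetical helper).

    imgs: array of shape (batch, height, width, channels).
    Returns the masked copies and a binary mask (1 = context kept, 0 = patch removed).
    """
    rng = np.random.default_rng() if rng is None else rng
    batch, height, width, _ = imgs.shape
    masks = np.ones_like(imgs)
    for i in range(batch):
        # Pick the top-left corner so the patch stays inside the image
        top = rng.integers(0, height - mask_size + 1)
        left = rng.integers(0, width - mask_size + 1)
        masks[i, top:top + mask_size, left:left + mask_size, :] = 0
    masked = imgs * masks + fill_value * (1 - masks)
    return masked, masks

def fill_in(real_imgs, generated_imgs, masks):
    """Keep the real context and take only the missing patch from the generator."""
    return masks * real_imgs + (1 - masks) * generated_imgs

rng = np.random.default_rng(0)
real_batch = rng.uniform(-1, 1, size=(4, 32, 32, 1)).astype("float32")
masked, masks = mask_random_patch(real_batch, mask_size=10, rng=rng)
generated = np.zeros_like(real_batch)  # stand-in for the generator's output
completed = fill_in(real_batch, generated, masks)
```

Here `completed` equals the real image everywhere except inside the masked square, where it carries the generated pixels. This completed image is what would be passed to 𝐷.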

Network design

Context-Conditional Generative Adversarial Nets or CCGANs Architecture by fernanda rodríguez.

x is the real data and z is the latent space.

Cost function

Cost function CCGANs by fernanda rodríguez.
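For readers without access to the figure, here is a sketch of the objective written from the definitions in Denton et al. (2016). The notation is an assumption and may differ from the figure: m is a binary mask (1 on the context, 0 on the missing patch), ⊙ is element-wise multiplication, x_G is the Generator output, and x_I is the in-painted image passed to 𝐷.

```latex
\min_G \max_D \;
  \mathbb{E}_{x \sim \mathcal{X}} \big[ \log D(x) \big]
  + \mathbb{E}_{x \sim \mathcal{X},\, m \sim \mathcal{M}} \big[ \log \big( 1 - D(x_I) \big) \big],
\qquad
x_G = G(m \odot x, z),
\quad
x_I = (1 - m) \odot x_G + m \odot x
```

Note that the training code in this article compiles the validity output with an 'mse' loss, so it optimizes a least-squares variant of this objective, not the log loss shown here.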

Training CCGANs

1. Data: MNIST dataset

from tensorflow.keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
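Before training, the raw images need some preparation. The sketch below rests on two assumptions that are not confirmed by the snippets in this article: pixels are scaled to [-1, 1] to match the generator's tanh output, and images are padded from 28×28 to 32×32 so that four halvings followed by four doublings return the original size. `preprocess_images` is a hypothetical helper, and random data stands in for `X_train`.

```python
import numpy as np

def preprocess_images(x, target_size=32):
    """Scale uint8 images to [-1, 1], add a channel axis, and pad to target_size.

    Assumes square images of shape (N, size, size) with size <= target_size.
    """
    x = x.astype("float32") / 127.5 - 1.0   # [0, 255] -> [-1, 1]
    x = np.expand_dims(x, axis=-1)          # (N, 28, 28) -> (N, 28, 28, 1)
    pad = target_size - x.shape[1]
    before, after = pad // 2, pad - pad // 2
    # Pad with -1.0, the value of a black background pixel after scaling
    return np.pad(
        x,
        ((0, 0), (before, after), (before, after), (0, 0)),
        constant_values=-1.0,
    )

# Stand-in for X_train so the example runs without downloading MNIST
X_demo = np.random.default_rng(0).integers(0, 256, size=(8, 28, 28), dtype=np.uint8)
X_ready = preprocess_images(X_demo)
print(X_ready.shape)
```

Resizing with interpolation would be an equally valid choice. Padding is used here only because it needs nothing beyond NumPy.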

2. Model:

  • Generator
# Generator network (functional API): masked image in, full image out
img_g = Input(shape=img_shape)
# Downsampling
d1 = Conv2D(gf, kernel_size=k, strides=s, padding='same')(img_g)
d1 = LeakyReLU(alpha=0.2)(d1)
d2 ...
d3 ...
d4 ...
d4 = BatchNormalization(momentum=0.8)(d4)

# Upsampling
u1 = UpSampling2D(size=2)(d4)
u1 = Conv2D(gf*4, kernel_size=k, strides=1, padding='same', activation='relu')(u1)
u1 = BatchNormalization(momentum=0.8)(u1)
u2 ...
u3 ...
u4 ...
u4 = Conv2D(1, kernel_size=4, strides=1, padding='same', activation='tanh')(u4)

generator = Model(img_g, u4)
  • Discriminator
# Discriminator network
discriminator = Sequential()
discriminator.add(Conv2D(64, kernel_size=4, strides=2, padding='same', input_shape=img_shape))
discriminator.add(LeakyReLU(alpha=0.8))
# Conv2D 128, Conv2D 256
...

img_d = Input(shape=img_shape)
features = discriminator(img_d)

validity = Conv2D(1, kernel_size=k, strides=1, padding='same')(features)

label = Flatten()(features)
label = Dense(num_classes+1, activation="softmax")(label)

discriminator = Model(img_d, [validity, label])
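It helps to check the spatial sizes these two networks produce. With padding='same', a Conv2D with stride s maps a side of length n to ceil(n / s). The sketch below assumes a stride of 2 in every downsampling layer, four down and four up blocks in the generator, and three strided convolutions in the discriminator. The elided layers in the snippets above may differ, and the function names are hypothetical.

```python
import math

def same_conv_size(size, stride):
    """Spatial output size of a convolution with padding='same'."""
    return math.ceil(size / stride)

def generator_output_size(size, n_down=4, stride=2, n_up=4, up_factor=2):
    """Side length after n_down strided convs followed by n_up upsamplings."""
    for _ in range(n_down):
        size = same_conv_size(size, stride)
    for _ in range(n_up):
        size *= up_factor
    return size

def discriminator_patch_size(size, n_down=3, stride=2):
    """Side length of the validity map after n_down strided convs."""
    for _ in range(n_down):
        size = same_conv_size(size, stride)
    return size

print(generator_output_size(28))     # 28 -> 14 -> 7 -> 4 -> 2 -> 4 -> 8 -> 16 -> 32
print(generator_output_size(32))     # 32 -> 16 -> 8 -> 4 -> 2 -> 4 -> 8 -> 16 -> 32
print(discriminator_patch_size(32))  # 32 -> 16 -> 8 -> 4
```

Under these assumptions a 28×28 input would come back as 32×32, which is one reason to bring MNIST images to 32×32 first. The validity output would then be a 4×4 map of per-patch scores, not a single number.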

3. Compile

optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
discriminator.compile(optimizer=optimizer, loss=['mse', 'categorical_crossentropy'], loss_weights=[0.5, 0.5], metrics=['accuracy'])
# The generator takes a masked image as input and generates a full image
masked_img = Input(shape=(img_shape))
gen_img = generator(masked_img)
# Freeze the discriminator inside the combined model
discriminator.trainable = False
validity, _ = discriminator(gen_img)

d_g = Model(masked_img, validity)

d_g.compile(optimizer=optimizer, loss='mse', metrics=['accuracy'])

4. Fit

# Train Discriminator weights
discriminator.trainable = True

# Real samples
img_real = X_train[i*batch_size:(i+1)*batch_size]
real_labels = y_train[i*batch_size:(i+1)*batch_size]

d_loss_real = discriminator.train_on_batch(x=img_real, y=[real, real_labels])

# Fake Samples
masked_imgs = mask_randomly(img_real)
gen_imgs = generator.predict(masked_imgs)

d_loss_fake = discriminator.train_on_batch(x=gen_imgs, y=[fake, fake_labels])

# Discriminator loss
d_loss_batch = 0.5 * (d_loss_real[0] + d_loss_fake[0])

# Train Generator weights
discriminator.trainable = False
d_g_loss_batch = d_g.train_on_batch(x=masked_imgs, y=real)
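The training step above uses `real`, `fake`, and `fake_labels` without showing how they are built. A minimal sketch follows, under two assumptions: the validity output is a 4×4×1 patch map, and generated images are assigned to an extra class with index `num_classes`. `make_targets` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def make_targets(batch_size, patch_shape, num_classes):
    """Build validity targets and the one-hot 'fake' class labels."""
    real = np.ones((batch_size,) + patch_shape, dtype="float32")   # target 1 for real images
    fake = np.zeros((batch_size,) + patch_shape, dtype="float32")  # target 0 for generated images
    # Generated images go to the extra class at index num_classes
    fake_labels = np.zeros((batch_size, num_classes + 1), dtype="float32")
    fake_labels[:, num_classes] = 1.0
    return real, fake, fake_labels

batch_size, num_classes = 4, 10
real, fake, fake_labels = make_targets(batch_size, (4, 4, 1), num_classes)

# Digit labels must also be one-hot with num_classes + 1 columns to match
# the Dense(num_classes+1) softmax head and categorical_crossentropy
y_batch = np.array([3, 0, 9, 1])  # stand-in for a slice of y_train
real_labels = np.eye(num_classes + 1, dtype="float32")[y_batch]
```

Since the discriminator is compiled with categorical_crossentropy, integer labels taken straight from `y_train` would need this one-hot conversion before being passed to `train_on_batch`.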

5. Evaluate

# plotting the metrics 
plt.plot(d_loss)
plt.plot(d_g_loss)
plt.show()

CCGANs — MNIST results

epoch = 1/100, d_loss=0.419, g_loss=0.205 in CCGAN_MNIST
Context-Conditional Generative Adversarial Nets or CCGANs by fernanda rodríguez.
epoch = 100/100, d_loss=0.063, g_loss=0.200 in CCGAN_MNIST

Train summary

Train summary CCGANs by fernanda rodríguez.

Github repository

See the complete CCGAN training with the MNIST dataset, using Python and Keras/TensorFlow in a Jupyter Notebook.

fernanda rodríguez

hi, i’m maría fernanda rodríguez r. multimedia engineer. data scientist. front-end dev. phd candidate: augmented reality + machine learning.