GANs — Conditional GANs with CIFAR10 (Part 9)

fernanda rodríguez
Jun 16, 2020

Brief theoretical introduction to Conditional Generative Adversarial Nets or CGANs and practical implementation using Python and Keras/TensorFlow in Jupyter Notebook.

Conditional Generative Adversarial Nets or CGANs by fernanda rodríguez.

In this article, you will find:

  • Research paper,
  • Definition, network design, and cost function, and
  • Training CGANs with CIFAR10 dataset using Python and Keras/TensorFlow in Jupyter Notebook.

Research Paper

Mirza, M., & Osindero, S. (2014). Conditional Generative Adversarial Nets. ArXiv, abs/1411.1784.

Conditional Generative Adversarial Nets — CGANs

Generative adversarial nets can be extended to a conditional model if both the Generator G and Discriminator D are conditioned on some extra information y.

  • y could be any kind of auxiliary information, such as class labels or data from other modalities.

We can perform the conditioning by feeding y into both the Generator G and the Discriminator D as an additional input layer.

  • Generator: The prior input noise p(z) and y are combined in a joint hidden representation, and the adversarial training framework allows for considerable flexibility in how this hidden representation is composed.
  • Discriminator: x and y are presented as inputs to a discriminative function.
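Here is a minimal NumPy sketch of this conditioning. It assumes that y is a one-hot class label, as in the CIFAR10 code below. The label y is concatenated to the noise z to form the Generator's joint input, and to the (flattened) data x for the Discriminator. The helper names `one_hot` and `condition` and the sizes are illustrative assumptions. The Keras models below apply the same idea with `Concatenate` layers, although their Discriminator concatenates y after the convolutions rather than to raw pixels.

```python
import numpy as np

def one_hot(labels, num_classes=10):
    """Return a (batch, num_classes) one-hot matrix for integer class labels."""
    return np.eye(num_classes, dtype=np.float32)[labels]

def condition(features, labels, num_classes=10):
    """Append the one-hot label y to each row of `features` (joint representation)."""
    return np.concatenate([features, one_hot(labels, num_classes)], axis=1)

rng = np.random.default_rng(0)
batch_size, latent_dim = 4, 100                      # assumed sizes
z = rng.normal(size=(batch_size, latent_dim)).astype(np.float32)
y = np.array([0, 3, 3, 9])                           # CIFAR10 classes: airplane, cat, cat, truck

g_input = condition(z, y)                            # Generator input: [z, y]
print(g_input.shape)                                 # (4, 110)

x_flat = rng.uniform(-1, 1, size=(batch_size, 32 * 32 * 3)).astype(np.float32)
d_input = condition(x_flat, y)                       # Discriminator input: [x, y]
print(d_input.shape)                                 # (4, 3082)
```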

Network design

Conditional Generative Adversarial Nets or CGANs Architecture by fernanda rodríguez.

x is the real data, y the class labels, and z the latent (noise) vector.

Cost function

Cost function CGANs by fernanda rodríguez.
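The cost function figure is not reproduced in this text. As given in Mirza & Osindero (2014), the CGAN objective is the original GAN minimax game with both D and G conditioned on y:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x \mid y)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z \mid y))\big)\big]
```

The Keras code in the training section optimizes this objective with binary cross-entropy. D is pushed toward 1 on real (x, y) pairs (with a smoothed target) and toward 0 on generated ones. G is trained through the frozen D with target 1, so it minimizes -log D(G(z|y)), the common non-saturating form, rather than log(1 - D(G(z|y))).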

Training CGANs

1. Data: CIFAR10 dataset
from tensorflow.keras.datasets import cifar10
(X_train, y_train), (X_test, y_test) = cifar10.load_data()
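The snippet above only loads the data. Generator outputs are often squashed with tanh, so a common preprocessing step is to scale the uint8 pixels from [0, 255] to [-1, 1] before training. The matching inverse is then used when displaying samples. This is shown as an assumption, since the notebook's preprocessing and the generator's final layer are not reproduced here. The function names are hypothetical, and a synthetic batch stands in for CIFAR10 so that nothing is downloaded.

```python
import numpy as np

def scale_images(images):
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return (images.astype(np.float32) - 127.5) / 127.5

def unscale_images(x):
    """Map generator-style outputs in [-1, 1] back to displayable uint8 pixels."""
    return np.clip(np.rint(x * 127.5 + 127.5), 0, 255).astype(np.uint8)

# Synthetic stand-in for a CIFAR10 batch with shape (N, 32, 32, 3)
batch = np.random.default_rng(0).integers(0, 256, size=(8, 32, 32, 3), dtype=np.uint8)
x = scale_images(batch)
print(x.dtype, x.min() >= -1.0, x.max() <= 1.0)   # float32 True True
```

With the real data, this would be applied as `X_train = scale_images(X_train)` before the training loop.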

2. Model:

  • Generator
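from tensorflow.keras.layers import (Input, Dense, BatchNormalization, LeakyReLU,
                                     Reshape, Conv2DTranspose, Concatenate)
from tensorflow.keras.models import Model

latent_dim = 100
# Generator inputs: latent vector z and one-hot class label y
z = Input(shape=(latent_dim,))
labels = Input(shape=(10,))

# Generator network: combine z and y into a joint representation
merged_layer = Concatenate()([z, labels])

# FC: 2x2x512
generator = Dense(2*2*512, activation='relu')(merged_layer)
generator = BatchNormalization(momentum=0.9)(generator)
generator = LeakyReLU(alpha=0.1)(generator)
generator = Reshape((2, 2, 512))(generator)
# Conv 1: 4x4x256
generator = Conv2DTranspose(256, kernel_size=5, strides=2, padding='same')(generator)
generator = BatchNormalization(momentum=0.9)(generator)
generator = LeakyReLU(alpha=0.1)(generator)
# Conv 2, 3 and 4
...
generator = Model(inputs=[z, labels], outputs=generator, name='generator')
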
  • Discriminator
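from tensorflow.keras.layers import (Input, Conv2D, BatchNormalization, LeakyReLU,
                                     Flatten, Concatenate, Dense)
from tensorflow.keras.models import Model

# Discriminator inputs: a 32x32x3 CIFAR10 image x and its one-hot class label y
img_input = Input(shape=(32, 32, 3))
labels = Input(shape=(10,))

# Conv 1: 16x16x64
discriminator = Conv2D(64, kernel_size=5, strides=2, padding='same')(img_input)
discriminator = BatchNormalization(momentum=0.9)(discriminator)
discriminator = LeakyReLU(alpha=0.1)(discriminator)

# Conv 2, 3 and 4
...
# FC
discriminator = Flatten()(discriminator)

# Concatenate the image features with the label y
merged_layer = Concatenate()([discriminator, labels])
discriminator = Dense(512, activation='relu')(merged_layer)

# Output: probability that (x, y) is a real pair
discriminator = Dense(1, activation='sigmoid')(discriminator)

discriminator = Model(inputs=[img_input, labels], outputs=discriminator, name='discriminator')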
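The spatial sizes in the layer comments ("FC: 2x2x512", "Conv 1: 4x4x256", "Conv 1: 16x16x64") follow from simple shape arithmetic. In Keras, a stride-2 `Conv2DTranspose` with `padding='same'` doubles the spatial size, and a stride-2 `Conv2D` with `padding='same'` halves it, rounding up. The sketch below traces the sizes through both networks. It assumes that the elided Conv 2, 3 and 4 layers also use `strides=2` and `padding='same'`, and it says nothing about their channel counts, which are not shown.

```python
def conv_same_out(size, stride=2):
    """Spatial size after Conv2D(..., padding='same'): ceil(size / stride)."""
    return -(-size // stride)

def conv_transpose_same_out(size, stride=2):
    """Spatial size after Conv2DTranspose(..., padding='same'): size * stride."""
    return size * stride

# Generator: Reshape to 2x2, then Conv 1-4 (assumed stride 2 each) up to a 32x32 image
g_sizes = [2]
for _ in range(4):
    g_sizes.append(conv_transpose_same_out(g_sizes[-1]))
print(g_sizes)   # [2, 4, 8, 16, 32]

# Discriminator: 32x32 input, then Conv 1-4 (assumed stride 2 each) before Flatten
d_sizes = [32]
for _ in range(4):
    d_sizes.append(conv_same_out(d_sizes[-1]))
print(d_sizes)   # [32, 16, 8, 4, 2]
```

Under these assumptions, four upsampling steps from 2x2 land exactly on CIFAR10's 32x32 resolution, and the discriminator mirrors that path back down.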

3. Compile

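from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

# Compile the discriminator on its own (its weights are trainable here)
discriminator.compile(Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy',
                      metrics=['binary_accuracy'])

# Freeze the discriminator inside the stacked generator + discriminator model
discriminator.trainable = False
label = Input(shape=(10,), name='label')
z = Input(shape=(100,), name='z')

fake_img = generator([z, label])
validity = discriminator([fake_img, label])
d_g = Model([z, label], validity, name='adversarial')

d_g.compile(Adam(learning_rate=0.0004, beta_1=0.5), loss='binary_crossentropy',
            metrics=['binary_accuracy'])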

4. Fit

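import numpy as np
from tensorflow.keras.utils import to_categorical

# Assumed settings (not shown in the original snippet)
epochs = 100                   # the results below report epoch = 100/100
batch_size = 32
latent_dim = 100               # matches the z Input of d_g
smooth = 0.1                   # one-sided label smoothing for real samples
real = np.ones(shape=(batch_size, 1))
fake = np.zeros(shape=(batch_size, 1))
d_loss, d_g_loss = [], []

for e in range(epochs):
    for i in range(len(X_train) // batch_size):
        # Train Discriminator weights
        # (in tf.keras 2.x the trainable flag is fixed when a model is compiled, so
        # `discriminator` trains D and `d_g` keeps D frozen regardless of this toggle)
        discriminator.trainable = True

        # Real samples
        X_batch = X_train[i*batch_size:(i+1)*batch_size]
        real_labels = to_categorical(y_train[i*batch_size:(i+1)*batch_size].reshape(-1, 1), num_classes=10)
        d_loss_real = discriminator.train_on_batch(x=[X_batch, real_labels], y=real * (1 - smooth))

        # Fake samples
        z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
        random_labels = to_categorical(np.random.randint(0, 10, batch_size).reshape(-1, 1), num_classes=10)
        X_fake = generator.predict_on_batch([z, random_labels])
        d_loss_fake = discriminator.train_on_batch(x=[X_fake, random_labels], y=fake)

        # Discriminator loss (index 0 is the loss, index 1 the binary accuracy)
        d_loss_batch = 0.5 * (d_loss_real[0] + d_loss_fake[0])

        # Train Generator weights through d_g, where the discriminator is frozen
        discriminator.trainable = False
        z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
        random_labels = to_categorical(np.random.randint(0, 10, batch_size).reshape(-1, 1), num_classes=10)
        d_g_loss_batch = d_g.train_on_batch(x=[z, random_labels], y=real)

    # Keep the last batch's losses of each epoch for plotting
    d_loss.append(d_loss_batch)
    d_g_loss.append(d_g_loss_batch[0])
    print(f'epoch = {e + 1}/{epochs}, d_loss={d_loss_batch:.3f}, g_loss={d_g_loss_batch[0]:.3f}')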
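To make the discriminator update concrete, here is essentially the arithmetic that Keras' `binary_crossentropy` performs: the batch mean of -[t·log p + (1-t)·log(1-p)], with p clipped away from 0 and 1. It is written out in NumPy with hypothetical discriminator outputs. The example shows the effect of the smoothed real target `real * (1 - smooth)` and how `d_loss_batch` averages the real and fake terms. The value `smooth = 0.1` and the output probabilities are assumptions for illustration.

```python
import numpy as np

def bce(targets, probs, eps=1e-7):
    """Mean binary cross-entropy, clipping probabilities away from 0 and 1."""
    p = np.clip(probs, eps, 1 - eps)
    return float(np.mean(-(targets * np.log(p) + (1 - targets) * np.log(1 - p))))

smooth = 0.1                                  # assumed label-smoothing amount
d_real = np.array([0.9, 0.8])                 # hypothetical D(x|y) on real pairs
d_fake = np.array([0.2, 0.1])                 # hypothetical D(G(z|y)) on generated pairs

d_loss_real = bce(np.ones(2) * (1 - smooth), d_real)   # targets 0.9 instead of 1.0
d_loss_fake = bce(np.zeros(2), d_fake)                 # targets 0.0
d_loss_batch = 0.5 * (d_loss_real + d_loss_fake)

g_loss = bce(np.ones(2), d_fake)              # generator step: fakes labelled "real"
print(round(d_loss_real, 4), round(d_loss_fake, 4), round(d_loss_batch, 4), round(g_loss, 4))
# 0.3434 0.1643 0.2538 1.956
```

With smoothed targets, the real term cannot reach 0 even for a perfect discriminator. Its minimum, at p = 0.9, is the entropy of a Bernoulli(0.9), about 0.325, so `d_loss_batch` is bounded below by roughly 0.16.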

5. Evaluate

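import matplotlib.pyplot as plt

# Plot the recorded discriminator and generator (adversarial) losses
plt.plot(d_loss, label='Discriminator loss')
plt.plot(d_g_loss, label='Generator loss')
plt.ylabel('Loss')
plt.legend()
plt.show()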

CGANs — CIFAR10 results

epoch = 1/100, d_loss=0.291, g_loss=4.552 in CGAN_CIFAR10
Conditional Generative Adversarial Nets or CGANs by fernanda rodríguez.
epoch = 100/100, d_loss=0.188, g_loss=7.892 in CGAN_CIFAR10
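The result images are grids of generated samples, each conditioned on a class label. The actual layout of these figures is not recoverable here. Assuming a 10x10 layout with one CIFAR10 class per row, the conditioning labels for such a grid could be built as follows. The name `grid_labels` is hypothetical. The resulting `z` and `labels` arrays would be fed to the generator as in the fit loop, e.g. `generator.predict_on_batch([z, labels])`.

```python
import numpy as np

CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck']

def grid_labels(samples_per_class=10, num_classes=10):
    """One-hot labels for a grid: row r holds `samples_per_class` copies of class r."""
    class_ids = np.repeat(np.arange(num_classes), samples_per_class)
    return np.eye(num_classes, dtype=np.float32)[class_ids]

labels = grid_labels()
z = np.random.default_rng(42).normal(size=(len(labels), 100))   # one noise vector per cell
print(labels.shape, z.shape)                  # (100, 10) (100, 100)
print(CLASSES[labels[10].argmax()])           # automobile (first cell of row 1)
```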

Train summary

Train summary CGAN by fernanda rodríguez

GitHub repository

See the complete CGAN training on the CIFAR10 dataset, using Python and Keras/TensorFlow in Jupyter Notebook.


Written by fernanda rodríguez

hi, i’m maría fernanda rodríguez r. multimedia engineer. data scientist. front-end dev. phd candidate: augmented reality + machine learning.
