GANs — Generative Adversarial Network with MNIST (Part 2)

fernanda rodríguez
3 min read · Jun 12, 2020

A brief theoretical introduction to Generative Adversarial Networks (GANs) and a practical implementation using Python and Keras/TensorFlow in Jupyter Notebook.

Generative Adversarial Networks or GANs with MNIST by fernanda rodríguez.

In this article, you will find:

  • Research paper,
  • Definition, network design, and cost function, and
  • Training GANs with MNIST dataset using Python and Keras/TensorFlow in Jupyter Notebook.

Research Paper

Goodfellow, I.J., Pouget-Abadie, J., Mirza, M., Xu, B., Warde-Farley, D., Ozair, S., Courville, A.C., & Bengio, Y. (2014). Generative Adversarial Nets. ArXiv, abs/1406.2661.

Generative Adversarial Networks — GANs

Generative Adversarial Networks, or GANs, are a framework composed of two models, each represented by a neural network:

  • The first model, the Generator, aims to generate new data that resembles the real training data.
  • The second model, the Discriminator, aims to recognize whether an input sample is ‘real’ (it belongs to the original dataset) or ‘fake’ (it was produced by the Generator, which plays the role of a forger).

Network design

Generative Adversarial Network Architecture by fernanda rodríguez.

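x is the real data and z is a random noise vector sampled from the latent space. The Generator maps z to a fake sample G(z), and the Discriminator outputs D(·), its estimated probability that the input is real.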
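To make this data flow concrete, here is a toy NumPy sketch in which each network is a single random linear layer. This is a simplifying assumption for illustration only; the actual models are deeper Keras networks. The sizes latent_dim = 100 and batch_size = 16 are hypothetical. img_dim = 784 corresponds to a flattened 28×28 MNIST image.

```python
import numpy as np

rng = np.random.default_rng(0)
latent_dim, img_dim, batch_size = 100, 784, 16  # hypothetical sizes

# Toy one-layer "networks" with small random weights (illustration only)
W_g = rng.normal(scale=0.02, size=(latent_dim, img_dim))
W_d = rng.normal(scale=0.02, size=(img_dim, 1))

def G(z):
    """Generator: noise z -> fake image with values in [-1, 1] (tanh)."""
    return np.tanh(z @ W_g)

def D(x):
    """Discriminator: image x -> probability that x is real (sigmoid)."""
    return 1.0 / (1.0 + np.exp(-(x @ W_d)))

# Sample z from a standard normal latent space, then run G and D
z = rng.normal(loc=0.0, scale=1.0, size=(batch_size, latent_dim))
x_fake = G(z)
p_real = D(x_fake)
print(x_fake.shape, p_real.shape)  # (16, 784) (16, 1)
```

The shapes are the point here. Each noise vector becomes one flattened image, and each image becomes one probability.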

Cost function

Cost function GANs by fernanda rodríguez.
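For reference, the standard GAN cost function, introduced by Goodfellow et al. (2014), is a two-player minimax game over the value function V(D, G). Here D(x) is the Discriminator's probability that x is real, and G(z) is the Generator's output for noise z:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\big[\log D(x)\big]
  + \mathbb{E}_{z \sim p_z(z)}\big[\log\big(1 - D(G(z))\big)\big]
```

D tries to maximize V by assigning high probability to real samples and low probability to fakes. G tries to minimize V by pushing D(G(z)) toward 1.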

Read more about GANs here.

Training GANs

1. Data: MNIST dataset

from tensorflow.keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
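The generator's tanh output produces values in [-1, 1], and the Dense layers take flat vectors of img_dim values. For that reason, the 28×28 uint8 images are typically flattened and rescaled before training. Below is a minimal NumPy sketch that assumes this scaling. The helper name preprocess is hypothetical, and the example uses a small synthetic batch instead of the real download:

```python
import numpy as np

def preprocess(images):
    """Hypothetical helper: flatten (N, 28, 28) uint8 images into
    (N, 784) float32 vectors, mapping pixel 0 -> -1.0 and 255 -> 1.0."""
    images = images.astype(np.float32)
    scaled = (images - 127.5) / 127.5
    return scaled.reshape(len(images), -1)

# Synthetic stand-in for X_train: three all-black 28x28 images,
# with one white pixel in the first image
batch = np.zeros((3, 28, 28), dtype=np.uint8)
batch[0, 0, 0] = 255
X = preprocess(batch)
print(X.shape)  # (3, 784)
```

On the real data, this would be X_train = preprocess(X_train), which gives an array of shape (60000, 784).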

2. Model:

  • Generator
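from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization

# Assumes latent_dim (size of the noise vector z), img_dim (28 * 28 = 784
# flattened pixels) and init (a kernel initializer) are defined beforehand
generator = Sequential()

# Input layer and hidden layer 1
generator.add(Dense(128, input_shape=(latent_dim,), kernel_initializer=init))
generator.add(LeakyReLU(alpha=0.2))
generator.add(BatchNormalization(momentum=0.8))

# Hidden layers 2 and 3
...

# Output layer: tanh keeps the generated pixel values in [-1, 1]
generator.add(Dense(img_dim, activation='tanh'))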
  • Discriminator
# Discriminator network
discriminator = Sequential()

# Input layer and hidden layer 1
discriminator.add(Dense(128, input_shape=(img_dim,), kernel_initializer=init))
discriminator.add(LeakyReLU(alpha=0.2))

# Hidden layers 2 and 3
...

# Output layer: sigmoid gives the probability that the input is real
discriminator.add(Dense(1, activation='sigmoid'))
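The activation choices in these two networks can be checked numerically. In the small NumPy sketch below, the functions are written by hand rather than taken from Keras:

  • LeakyReLU with alpha=0.2 passes positive values through and scales negative values by 0.2.
  • tanh squashes the generator's output into (-1, 1), which matches the scaled pixels.
  • sigmoid turns the discriminator's single output into a probability in (0, 1).

```python
import numpy as np

def leaky_relu(x, alpha=0.2):
    """LeakyReLU: x where x > 0, alpha * x otherwise."""
    return np.where(x > 0, x, alpha * x)

def sigmoid(x):
    """Logistic function mapping any real number into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))

x = np.array([-5.0, -1.0, 0.0, 2.0])
print(leaky_relu(x))               # negatives are scaled by 0.2
print(np.tanh(0.0), sigmoid(0.0))  # 0.0 0.5
```

With plain ReLU, negative inputs would produce zero gradient. LeakyReLU keeps a small slope there, which is why GAN implementations commonly use it in hidden layers.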

3. Compile

from tensorflow.keras.optimizers import Adam

optimizer = Adam(learning_rate=0.0002, beta_1=0.5)
discriminator.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['binary_accuracy'])

# Freeze D before compiling d_g, so training d_g only updates the generator
discriminator.trainable = False

d_g = Sequential()
d_g.add(generator)
d_g.add(discriminator)
d_g.compile(optimizer=optimizer, loss='binary_crossentropy', metrics=['binary_accuracy'])
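Binary cross-entropy connects these compile calls to the cost function. Label real samples 1 and fakes 0. The discriminator's cross-entropy is then -log D(x) - log(1 - D(G(z))), which is the negative of V(D, G), so minimizing it maximizes V.

Training d_g with label 1 on fakes gives the generator loss -log D(G(z)). This is the non-saturating variant that Goodfellow et al. (2014) suggest because it gives stronger gradients early in training than minimizing log(1 - D(G(z))).

The NumPy sketch below uses made-up discriminator outputs. The function bce is written by hand and is not the Keras loss:

```python
import numpy as np

def bce(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy; predictions are clipped to avoid log(0)."""
    y_pred = np.clip(y_pred, eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(y_pred)
                           + (1 - y_true) * np.log(1 - y_pred))))

d_real = np.array([0.9, 0.8])  # made-up D(x) on real images
d_fake = np.array([0.2, 0.1])  # made-up D(G(z)) on generated images

# Discriminator: real samples labelled 1, fake samples labelled 0
loss_d_real = bce(np.ones(2), d_real)   # mean of -log D(x)
loss_d_fake = bce(np.zeros(2), d_fake)  # mean of -log(1 - D(G(z)))

# Generator through d_g: fake samples labelled 1 (non-saturating loss)
loss_g = bce(np.ones(2), d_fake)        # mean of -log D(G(z))

# Empirical value function V(D, G) on this batch
V = np.mean(np.log(d_real)) + np.mean(np.log(1 - d_fake))
print(np.isclose(loss_d_real + loss_d_fake, -V))  # True
```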

4. Fit

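import numpy as np

# One discriminator/generator update for mini-batch i. Assumes real and
# fake are label arrays of ones and zeros with shape (batch_size, 1), and
# smooth is a one-sided label-smoothing factor for the real labels.
# Note: tf.keras only applies changes to `trainable` at the next compile(),
# so the two toggles below record intent; D stays frozen inside d_g because
# it was frozen when d_g was compiled.

# Train Discriminator weights
discriminator.trainable = True

# Real samples
X_batch = X_train[i*batch_size:(i+1)*batch_size]
d_loss_real = discriminator.train_on_batch(x=X_batch, y=real * (1 - smooth))

# Fake samples
z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
X_fake = generator.predict_on_batch(z)
d_loss_fake = discriminator.train_on_batch(x=X_fake, y=fake)

# Discriminator loss: train_on_batch returns [loss, binary_accuracy]
d_loss_batch = 0.5 * (d_loss_real[0] + d_loss_fake[0])

# Train Generator weights: label the fakes as real so the generator
# learns to make the discriminator output values close to 1
discriminator.trainable = False
d_g_loss_batch = d_g.train_on_batch(x=z, y=real)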
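The step above runs once per mini-batch and depends on names defined outside the excerpt. Below is one way they could be set up, as a NumPy-only loop skeleton. The values batch_size = 64 and smooth = 0.1 are hypothetical and are not necessarily the notebook's settings. The Keras updates are left as a comment so the skeleton runs on its own:

```python
import numpy as np

batch_size = 64   # hypothetical batch size
smooth = 0.1      # hypothetical one-sided label-smoothing factor
epochs = 2        # the article's runs use 100; 2 keeps this sketch fast

# Targets for train_on_batch: real images get 0.9 instead of 1.0
real = np.ones(shape=(batch_size, 1))
fake = np.zeros(shape=(batch_size, 1))
real_targets = real * (1 - smooth)

# Synthetic stand-in for the preprocessed X_train of shape (N, 784)
X_train = np.zeros((640, 784), dtype=np.float32)
n_batches = len(X_train) // batch_size  # leftover samples are skipped

seen = 0
for epoch in range(epochs):
    for i in range(n_batches):
        X_batch = X_train[i*batch_size:(i+1)*batch_size]
        # ... discriminator and generator updates from the step above ...
        seen += len(X_batch)

print(n_batches, seen, real_targets[0, 0])  # 10 1280 0.9
```

Smoothing only the real labels (0.9 instead of 1.0) discourages the discriminator from becoming overconfident. The fake labels stay at 0.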

5. Evaluate

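import matplotlib.pyplot as plt

# Plot the discriminator and generator losses recorded during training
plt.plot(d_loss, label='Discriminator loss')
plt.plot(d_g_loss, label='Generator loss')
plt.legend()
plt.show()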
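The plot expects d_loss and d_g_loss to hold one value per epoch, but the excerpt does not show how those lists are filled. One plausible approach is to average the per-batch values at the end of each epoch. The NumPy sketch below does this with made-up batch losses. Note that d_g.train_on_batch returns [loss, binary_accuracy], so the generator loss is element [0]:

```python
import numpy as np

d_loss, d_g_loss = [], []  # one entry per epoch, as plotted above

# Made-up per-batch results for two epochs of three batches each.
# Each "d_g" entry mimics d_g.train_on_batch output: [loss, binary_accuracy]
epochs_of_batches = [
    {"d": [0.9, 0.8, 0.7], "d_g": [[1.2, 0.1], [1.3, 0.2], [1.4, 0.1]]},
    {"d": [0.6, 0.5, 0.4], "d_g": [[1.5, 0.3], [1.6, 0.2], [1.7, 0.4]]},
]

for batches in epochs_of_batches:
    # Average the discriminator batch losses for this epoch
    d_loss.append(np.mean(batches["d"]))
    # Keep only the loss (element 0) from each generator result
    d_g_loss.append(np.mean([result[0] for result in batches["d_g"]]))

print(np.round(d_loss, 3), np.round(d_g_loss, 3))
```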

GANs — MNIST results

epoch = 1/100, d_loss=0.819, g_loss=1.362 in 01_GAN_MNIST.
epoch = 100/100, d_loss=0.542, g_loss=1.658 in 01_GAN_MNIST.

Train summary

Train summary GANs by fernanda rodríguez.

GitHub repository

See the complete notebook for training a GAN on the MNIST dataset with Python and Keras/TensorFlow in Jupyter Notebook.


Written by fernanda rodríguez

hi, i’m maría fernanda rodríguez r. multimedia engineer. data scientist. front-end dev. phd candidate: augmented reality + machine learning.
