GANs — Deep Convolutional GANs with MNIST (Part 3)

fernanda rodríguez
3 min read · Jun 14, 2020


A brief theoretical introduction to Deep Convolutional Generative Adversarial Networks (DCGANs) and a practical implementation using Python and Keras/TensorFlow in Jupyter Notebook.

Deep Convolutional GANs or DCGANs with MNIST by fernanda rodríguez

In this article, you will find:

  • Research paper,
  • Definition, network design, and cost function, and
  • Training DCGANs with MNIST dataset using Python and Keras/TensorFlow in Jupyter Notebook.

Research Paper

Radford, A., Metz, L., & Chintala, S. (2015). Unsupervised representation learning with deep convolutional generative adversarial networks. arXiv preprint arXiv:1511.06434.

Deep Convolutional Generative Adversarial Networks — DCGANs

The main difference between the simple GAN and the DCGAN is that the DCGAN generator uses transposed convolutions (also called fractionally-strided convolutions or deconvolutions) to upsample the 2D image.
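To make the upsampling concrete, here is a minimal NumPy sketch of a stride-2 transposed convolution next to the strided convolution it is the transpose of. It assumes a single channel, a square kernel, and no padding, and it is an illustration rather than how Keras implements the layer. Without padding, a 7x7 input becomes (7 - 1) * 2 + 3 = 15 pixels wide; Keras's `padding='same'` instead produces input size times stride (14x14 here).

```python
import numpy as np

def conv2d_strided(x, kernel, stride):
    # Valid (no padding) strided cross-correlation: the downsampling operation
    k = kernel.shape[0]
    out_h = (x.shape[0] - k) // stride + 1
    out_w = (x.shape[1] - k) // stride + 1
    out = np.zeros((out_h, out_w))
    for i in range(out_h):
        for j in range(out_w):
            patch = x[i * stride:i * stride + k, j * stride:j * stride + k]
            out[i, j] = np.sum(patch * kernel)
    return out

def conv2d_transpose(y, kernel, stride):
    # Transposed convolution: each input value scatters a scaled copy of the
    # kernel into the output, with copies spaced `stride` pixels apart
    k = kernel.shape[0]
    out_h = (y.shape[0] - 1) * stride + k
    out_w = (y.shape[1] - 1) * stride + k
    out = np.zeros((out_h, out_w))
    for i in range(y.shape[0]):
        for j in range(y.shape[1]):
            out[i * stride:i * stride + k, j * stride:j * stride + k] += y[i, j] * kernel
    return out

rng = np.random.default_rng(0)
kernel = rng.normal(size=(3, 3))
z = rng.normal(size=(7, 7))        # a small 7x7 feature map
up = conv2d_transpose(z, kernel, stride=2)
print(up.shape)                    # (15, 15)

# The transposed convolution is the adjoint of the strided convolution:
# <conv(x), z> == <x, conv_transpose(z)> for any x of the upsampled size
x = rng.normal(size=up.shape)
lhs = np.sum(conv2d_strided(x, kernel, 2) * z)
rhs = np.sum(x * up)
print(np.isclose(lhs, rhs))        # True
```

The adjoint check is why the operation is called "transposed": written as a matrix, it is exactly the transpose of the strided convolution, so it maps a small feature map back to a larger one.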

DCGANs are mainly composed of:

  • Convolutional layers, without max pooling or fully connected hidden layers.
  • Strided convolutions for downsampling and transposed convolutions for upsampling.
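Replacing pooling with strides means the spatial size is controlled entirely by the stride. A small sketch of the output-size rules for Keras's `padding='same'` (assuming no `output_padding` on the transposed layer) traces how two stride-2 layers take a 28x28 MNIST image down to 7x7 and a 7x7 map back up to 28x28; the two-layer depth is only an example.

```python
import math

def conv_same_out(size, stride):
    # Conv2D with padding='same': output = ceil(input / stride)
    return math.ceil(size / stride)

def conv_transpose_same_out(size, stride):
    # Conv2DTranspose with padding='same': output = input * stride
    return size * stride

# Downsampling path (discriminator-style): two stride-2 convolutions
disc_sizes = [28]
for _ in range(2):
    disc_sizes.append(conv_same_out(disc_sizes[-1], 2))

# Upsampling path (generator-style): two stride-2 transposed convolutions
gen_sizes = [7]
for _ in range(2):
    gen_sizes.append(conv_transpose_same_out(gen_sizes[-1], 2))

print(disc_sizes)  # [28, 14, 7]
print(gen_sizes)   # [7, 14, 28]
```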


Network design

Deep Convolutional Generative Adversarial Networks Architecture or DCGAN by fernanda rodríguez.

Here, x is the real data and z is a noise vector sampled from the latent space.

Cost function

Cost function of DCGANs by fernanda rodríguez.
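The figure itself is not reproduced here. DCGANs change the architecture but keep the adversarial objective of the original GAN, so, assuming the figure shows that standard value function, the discriminator D maximizes and the generator G minimizes:

```latex
\min_G \max_D V(D, G) =
  \mathbb{E}_{x \sim p_{\text{data}}(x)}\left[\log D(x)\right]
  + \mathbb{E}_{z \sim p_z(z)}\left[\log\left(1 - D(G(z))\right)\right]
```

In the training code below, the generator is updated by labelling generated images as real, so in practice it minimizes -log D(G(z)) rather than log(1 - D(G(z))), the common non-saturating variant of this objective.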

Training DCGANs

1. Data: MNIST dataset
from tensorflow.keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
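`mnist.load_data()` returns raw `uint8` images of shape (60000, 28, 28), while the generator below ends in `tanh` and the discriminator expects `input_shape=(28, 28, 1)`. The preprocessing step isn't shown here; a common choice, assumed in this sketch (the function name `preprocess` is hypothetical), is to scale pixels to [-1, 1] and add a channel axis:

```python
import numpy as np

def preprocess(images):
    # uint8 pixels in [0, 255] -> float32 in [-1, 1], matching a tanh output
    x = (images.astype(np.float32) - 127.5) / 127.5
    # Add a channel axis: (N, 28, 28) -> (N, 28, 28, 1) for Conv2D layers
    return x.reshape(-1, 28, 28, 1)

# Stand-in for the MNIST array: one all-black and one all-white image
sample = np.stack([np.zeros((28, 28), np.uint8), np.full((28, 28), 255, np.uint8)])
X = preprocess(sample)
print(X.shape)  # (2, 28, 28, 1)
```

Applied to the real data, this would be `X_train = preprocess(X_train)`.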

2. Model:

  • Generator
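from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Reshape, Conv2DTranspose, BatchNormalization, LeakyReLU

# latent_dim and init (a kernel initializer) are defined elsewhere in the notebook
generator = Sequential()

# FC: 7x7x128
generator.add(Dense(7*7*128, input_shape=(latent_dim,), kernel_initializer=init))
generator.add(Reshape((7, 7, 128)))

# Conv 1: 14x14x64 (stride 2 doubles height and width)
generator.add(Conv2DTranspose(64, kernel_size=3, strides=2, padding='same'))
generator.add(BatchNormalization(momentum=0.8))
generator.add(LeakyReLU(0.2))  # ReLU(0.2) would instead cap activations at 0.2
# Conv 2
...
# Conv 4: 28x28x1, tanh output in [-1, 1]
generator.add(Conv2DTranspose(1, kernel_size=3, strides=2, padding='same', activation='tanh'))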
  • Discriminator
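from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, LeakyReLU, Flatten, Dense

# Discriminator network
discriminator = Sequential()

# Conv 1: 14x14x32 (stride 2 halves height and width)
discriminator.add(Conv2D(32, kernel_size=3, strides=2, padding='same', input_shape=(28, 28, 1), kernel_initializer=init))
discriminator.add(LeakyReLU(0.2))
# Conv 2 and 3
...
# FC
discriminator.add(Flatten())

# Output: probability that the input image is real
discriminator.add(Dense(1, activation='sigmoid'))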
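The discriminator uses `LeakyReLU(0.2)`: negative inputs are scaled by 0.2 instead of being zeroed, so gradients still flow through them. Note that in Keras the first positional argument of the `ReLU` layer is `max_value`, so a leaky slope has to be written as `LeakyReLU(0.2)` or `ReLU(negative_slope=0.2)`. A NumPy sketch of both behaviors, for illustration only:

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    # Identity for positive inputs, slope * x for negative inputs
    return np.where(x > 0, x, slope * x)

def relu_capped(x, max_value):
    # Standard ReLU clipped at max_value (what ReLU(max_value=...) computes)
    return np.clip(x, 0.0, max_value)

x = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])
leaky = leaky_relu(x)         # negatives shrink to -0.4 and -0.1
capped = relu_capped(x, 0.2)  # negatives become 0, positives saturate at 0.2
print(leaky)
print(capped)
```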

3. Compile

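from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam

# One Adam instance per model: newer Keras versions may refuse to share an optimizer across models
discriminator.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy', metrics=['binary_accuracy'])
discriminator.trainable = False  # freeze D inside the combined model d_g
d_g = Sequential()
d_g.add(generator)
d_g.add(discriminator)
d_g.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy', metrics=['binary_accuracy'])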
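Both models are compiled with `binary_crossentropy`, which turns the adversarial cost into per-sample losses. A NumPy sketch of that loss (clipping probabilities with a small epsilon, similar to Keras) applied to hypothetical discriminator outputs shows the two targets: the discriminator is pushed toward D(x) = 1 for real images and D(G(z)) = 0 for fakes, while `d_g` labels the fakes as real, so the generator minimizes -log D(G(z)).

```python
import numpy as np

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    # Mean of -[y*log(p) + (1-y)*log(1-p)], clipping p away from 0 and 1
    y_true = np.asarray(y_true, dtype=float)
    p = np.clip(np.asarray(y_pred, dtype=float), eps, 1 - eps)
    return float(np.mean(-(y_true * np.log(p) + (1 - y_true) * np.log(1 - p))))

# Hypothetical discriminator outputs: D(x) on real images, D(G(z)) on fakes
d_real = np.array([0.9, 0.8])
d_fake = np.array([0.2, 0.3])

# Discriminator: real labelled 1, fake labelled 0, averaged like d_loss_batch
d_loss = 0.5 * (binary_crossentropy(np.ones(2), d_real)
                + binary_crossentropy(np.zeros(2), d_fake))

# Generator (through d_g): fakes labelled 1, i.e. -log D(G(z))
g_loss = binary_crossentropy(np.ones(2), d_fake)
print(d_loss, g_loss)
```

With these made-up numbers the discriminator is winning, so its loss is small and the generator's loss is large.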

4. Fit

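import numpy as np

# These statements form the body of the per-batch training loop (batch index i).
# X_train, batch_size, latent_dim, the label arrays real and fake (presumably
# ones and zeros) and the label-smoothing factor smooth are defined elsewhere
# in the notebook.

# Train Discriminator weights
discriminator.trainable = True

# Real samples, with "real" labels smoothed from 1 down to 1 - smooth
X_batch = X_train[i*batch_size:(i+1)*batch_size]
d_loss_real = discriminator.train_on_batch(x=X_batch, y=real * (1 - smooth))

# Fake samples: generator output for Gaussian noise z, labelled as fake
z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
X_fake = generator.predict_on_batch(z)
d_loss_fake = discriminator.train_on_batch(x=X_fake, y=fake)

# Discriminator loss: mean of both losses (index 0; index 1 is binary_accuracy)
d_loss_batch = 0.5 * (d_loss_real[0] + d_loss_fake[0])

# Train Generator weights: with D frozen, d_g is trained to label fakes as real
discriminator.trainable = False
d_g_loss_batch = d_g.train_on_batch(x=z, y=real)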
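The real labels are multiplied by `(1 - smooth)` while the fake labels are left unchanged, i.e. one-sided label smoothing. With a smoothed target, the cross-entropy on a real image is minimized at D(x) = 1 - smooth instead of 1, which discourages an overconfident discriminator. A small numerical check, assuming a hypothetical `smooth = 0.1`:

```python
import numpy as np

def bce(y, p):
    # Binary cross-entropy for target y and prediction(s) p
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

smooth = 0.1                 # hypothetical smoothing factor for illustration
target = 1.0 * (1 - smooth)  # smoothed "real" label, as in real * (1 - smooth)

# Grid search over discriminator outputs p in (0, 1)
p_grid = np.linspace(0.001, 0.999, 999)
best_p = p_grid[np.argmin(bce(target, p_grid))]
best_p_unsmoothed = p_grid[np.argmin(bce(1.0, p_grid))]

print(round(best_p, 3))      # 0.9: optimum equals the smoothed target
print(best_p_unsmoothed)     # edge of the grid: the loss keeps falling toward p = 1
```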

5. Evaluate

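import matplotlib.pyplot as plt

# Plot the discriminator and generator losses collected during training
plt.plot(d_loss, label='Discriminator loss')
plt.plot(d_g_loss, label='Generator loss')
plt.legend()
plt.show()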

DCGANs — MNIST results

epoch = 1/100, d_loss=1.508, g_loss=10.259 in DCGAN_MNIST
epoch = 100/100, d_loss=0.252, g_loss=4.981 in DCGAN_MNIST

Train summary

Train summary DCGAN by fernanda rodríguez.

GitHub repository

See the complete DCGAN training on the MNIST dataset, using Python and Keras/TensorFlow in Jupyter Notebook.


Written by fernanda rodríguez

hi, i’m maría fernanda rodríguez r. multimedia engineer. data scientist. front-end dev. phd candidate: augmented reality + machine learning.
