GANs — Least Squares GANs with MNIST (Part 7)

fernanda rodríguez
Jun 16, 2020 · 3 min read


A brief theoretical introduction to Least Squares Generative Adversarial Networks (LSGANs) and a practical implementation using Python and Keras/TensorFlow in a Jupyter Notebook.

Least Squares Generative Adversarial Nets or LSGANs by fernanda rodríguez

In this article, you will find:

  • Research paper,
  • Definition, network design, and cost function, and
  • Training LSGANs on the MNIST dataset using Python and Keras/TensorFlow in a Jupyter Notebook.

Research Paper

Mao, X., Li, Q., Xie, H., Lau, R.Y., Wang, Z., & Smolley, S.P. (2017). Least Squares Generative Adversarial Networks. 2017 IEEE International Conference on Computer Vision (ICCV), 2813–2821.

Least Squares Generative Adversarial Networks — LSGANs

Least Squares Generative Adversarial Networks (LSGANs) adopt the least squares loss function for the Discriminator 𝐷.

The least squares loss function is able to move the fake samples toward the decision boundary, because it penalizes samples that lie a long way from the decision boundary even when they are on the correct side of it.

Another benefit of LSGANs is the improved stability of the learning process.


Network design

Least Squares Generative Adversarial Networks — LSGANs Architecture by fernanda rodríguez

Here, x is the real data and z is a sample from the latent space.

Cost function

Cost function LSGAN by fernanda rodríguez

where a and b are the labels for fake data and real data, respectively, and c denotes the value that 𝐺 wants 𝐷 to believe for fake data.
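
Written out, the objectives that 𝐷 and 𝐺 minimize in the paper are:

\min_D V(D) = \tfrac{1}{2}\,\mathbb{E}_{x \sim p_{\mathrm{data}}(x)}\big[(D(x) - b)^2\big] + \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z(z)}\big[(D(G(z)) - a)^2\big]

\min_G V(G) = \tfrac{1}{2}\,\mathbb{E}_{z \sim p_z(z)}\big[(D(G(z)) - c)^2\big]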

Training LSGANs

  1. Data: MNIST dataset
from tensorflow.keras.datasets import mnist

(X_train, y_train), (X_test, y_test) = mnist.load_data()
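
Because the generator below ends in a tanh activation, the images must be flattened and rescaled to [-1, 1] to match. A minimal preprocessing sketch (img_dim as the flattened 28×28 dimension is an assumption; this step is not shown in the original excerpts):

# Flatten the 28x28 images to 784-dim vectors and rescale pixel
# values from [0, 255] to [-1, 1] to match the tanh output (assumed step)
img_dim = 28 * 28
X_train = X_train.reshape(-1, img_dim).astype('float32')
X_train = (X_train - 127.5) / 127.5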

2. Model:
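
The snippets below rely on a few imports and hyperparameters that the excerpts do not show; a plausible setup (the specific values of latent_dim and init are assumptions, not necessarily the article's) is:

import numpy as np
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense, LeakyReLU, BatchNormalization, Input
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import RandomNormal

latent_dim = 100                             # size of the noise vector z (assumed)
init = RandomNormal(mean=0.0, stddev=0.02)   # common GAN weight initializer (assumed)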

  • Generator
# Generator network
generator = Sequential()
# Input layer and hidden layer 1
generator.add(Dense(128, input_shape=(latent_dim,), kernel_initializer=init))
generator.add(LeakyReLU(alpha=0.2))
generator.add(BatchNormalization(momentum=0.8))

# Hidden layer 2 and layer 3
...

# Output layer
generator.add(Dense(img_dim, activation='tanh'))
  • Discriminator
# Discriminator network
discriminator = Sequential()

# Hidden layer 1
discriminator.add(Dense(128, input_shape=(img_dim,), kernel_initializer=init))
discriminator.add(LeakyReLU(alpha=0.2))

# Hidden layer 2 and layer 3
...

# Output layer: a single linear score (no sigmoid); LSGAN regresses
# the raw outputs against the labels with MSE
discriminator.add(Dense(1))

3. Compile

optimizer = Adam(learning_rate=0.0002, beta_1=0.5)

discriminator.compile(optimizer=optimizer, loss='mse', metrics=['binary_accuracy'])
# Freeze the discriminator inside the combined model so that
# training d_g updates only the generator's weights
discriminator.trainable = False

z = Input(shape=(latent_dim,))
img = generator(z)
decision = discriminator(img)
d_g = Model(inputs=z, outputs=decision)

# Optimize w.r.t. MSE loss instead of crossentropy
d_g.compile(optimizer=optimizer, loss='mse', metrics=['binary_accuracy'])
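
The label arrays a, b, and c used during training are not defined in these excerpts; a minimal sketch using the 0-1 labeling scheme from the paper (the shapes follow the half-real/half-fake batching below, and batch_size is an assumed hyperparameter):

batch_size = 64                      # assumed hyperparameter
a = np.zeros((batch_size // 2, 1))   # target for fake samples at D
b = np.ones((batch_size // 2, 1))    # target for real samples at D
c = np.ones((batch_size, 1))         # value G wants D to output for fakes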

4. Fit

# Train Discriminator weights
discriminator.trainable = True

# Real samples
X_real = X_train[i*batch_size//2:(i+1)*batch_size//2]

# Fake Samples
z = np.random.normal(loc=0, scale=1, size=(batch_size//2, latent_dim))
X_fake = generator.predict_on_batch(z)

# Discriminator loss
d_loss_batch = discriminator.train_on_batch(
x=np.concatenate((X_fake, X_real), axis=0),
y=np.concatenate((a, b), axis=0)
)

# Train Generator weights
discriminator.trainable = False

z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
d_g_loss_batch = d_g.train_on_batch(x=z, y=c)
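
These fragments run once per batch; a minimal outer loop assembling them (epochs, num_batches, and the history lists are assumed names, not the article's) might look like:

epochs = 100
num_batches = X_train.shape[0] // batch_size

d_loss, d_g_loss = [], []
for epoch in range(epochs):
    for i in range(num_batches):
        # Discriminator update on half real, half fake samples
        discriminator.trainable = True
        X_real = X_train[i*batch_size//2:(i+1)*batch_size//2]
        z = np.random.normal(loc=0, scale=1, size=(batch_size//2, latent_dim))
        X_fake = generator.predict_on_batch(z)
        d_loss_batch = discriminator.train_on_batch(
            x=np.concatenate((X_fake, X_real), axis=0),
            y=np.concatenate((a, b), axis=0))

        # Generator update through the frozen combined model
        discriminator.trainable = False
        z = np.random.normal(loc=0, scale=1, size=(batch_size, latent_dim))
        d_g_loss_batch = d_g.train_on_batch(x=z, y=c)

    # Keep one loss value per epoch for the plots below
    d_loss.append(d_loss_batch[0])      # train_on_batch returns [loss, metric]
    d_g_loss.append(d_g_loss_batch[0])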

5. Evaluate

import matplotlib.pyplot as plt

# Plot the per-epoch discriminator and generator losses
plt.plot(d_loss, label='discriminator loss')
plt.plot(d_g_loss, label='generator loss')
plt.legend()
plt.show()
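
To inspect generated digits like those in the results below, a small sampling sketch can be used (the 4×4 grid size is an arbitrary choice):

# Sample a grid of images from the trained generator
z = np.random.normal(loc=0, scale=1, size=(16, latent_dim))
gen_imgs = generator.predict(z)
gen_imgs = 0.5 * gen_imgs + 0.5   # rescale from [-1, 1] back to [0, 1]

fig, axs = plt.subplots(4, 4, figsize=(4, 4))
for idx, ax in enumerate(axs.flat):
    ax.imshow(gen_imgs[idx].reshape(28, 28), cmap='gray')
    ax.axis('off')
plt.show()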

LSGANs — MNIST results

epoch = 1/100, d_loss=0.159, g_loss=0.474 in LSGAN_MNIST
epoch = 100/100, d_loss=0.124, g_loss=0.620 in LSGAN_MNIST

Train summary

Train summary LSGAN by fernanda rodríguez

Github repository

See the complete LSGAN training on the MNIST dataset, using Python and Keras/TensorFlow, in the Jupyter Notebook.


fernanda rodríguez

hi, i’m maría fernanda rodríguez r. multimedia engineer. data scientist. front-end dev. phd candidate: augmented reality + machine learning.