Jonathan Hui

Summary

The article discusses Unrolled GAN, a technique designed to mitigate mode collapse in Generative Adversarial Networks (GANs) by allowing the generator to anticipate the discriminator's future updates over a series of k steps.

Abstract

Unrolled GAN is an advanced GAN architecture that addresses the common issue of mode collapse by enabling the generator to "look ahead" and predict the discriminator's optimization steps. By unrolling k steps, the generator can prepare for the discriminator's counteractions, leading to a more stable training process and a reduced likelihood of overfitting to a specific discriminator. This approach improves the diversity and quality of generated samples, as demonstrated in experiments with a mixture of Gaussian distributions and with RNN generators. The article provides an in-depth explanation of the training process for both the discriminator and generator, including the use of graph replacement in TensorFlow to implement the unrolled steps. It also references the original Unrolled GAN paper and its TensorFlow implementation for further reading.

Opinions

  • The author believes that Unrolled GANs are effective in reducing mode collapse and improving model stability.
  • The technique of unrolling steps is seen as a strategic move in the "game" between the generator and discriminator, akin to anticipating an opponent's actions in a real game.
  • The article suggests that the number of unrolled steps (typically 5 to 10) is crucial for good model performance.
  • The author emphasizes the simplicity of the Unrolled GAN implementation, despite its significant impact on GAN training.
  • Experiments comparing Unrolled GANs to standard GANs are presented to illustrate the superior performance of Unrolled GANs in generating diverse and high-quality samples.
  • The article encourages readers to explore additional resources for a comprehensive understanding of GAN improvements and the full series of articles on GANs.

GAN — Unrolled GAN (how to reduce mode collapse)

Photo by Ethan Hu

Intuition: In any game, you look ahead a few moves to anticipate your opponent and plan your next move accordingly. In Unrolled GAN, we give the generator the opportunity to unroll k steps of how the discriminator may optimize itself. We then update the generator by backpropagating the cost calculated at the final (k-th) step. This lookahead discourages the generator from exploiting local optima that the discriminator can easily counteract; without it, the model may oscillate or even become unstable. Unrolled GAN lowers the chance that the generator overfits to a specific discriminator, which lessens mode collapse and improves stability.

This article is part of the series on GAN. Since mode collapse is common, we spend some time exploring Unrolled GAN to see how mode collapse may be addressed.

Discriminator training

In GAN, we compute the cost function and use backpropagation to fit the model parameters of the discriminator D and the generator G.

We redraw the diagram below to emphasize the model parameters θ. The red arrows show how we backpropagate the cost function f to fit the model parameters.

Here are the cost function and the gradient descent updates (we use plain gradient descent for illustration).
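For reference, the GAN value function and the plain gradient updates can be written as follows, using the notation of the Unrolled GAN paper as we understand it (η is the learning rate). The discriminator ascends f and the generator descends it. In the code later in this article, the discriminator's cross-entropy `loss` is a sample estimate of -f, which is why the generator minimizes `-unrolled_loss`.

```latex
f(\theta_G, \theta_D) = \mathbb{E}_{x \sim p_{\text{data}}}\left[\log D(x; \theta_D)\right]
  + \mathbb{E}_{z \sim \mathcal{N}(0, I)}\left[\log\left(1 - D(G(z; \theta_G); \theta_D)\right)\right]

\theta_D \leftarrow \theta_D + \eta \, \frac{\partial f(\theta_G, \theta_D)}{\partial \theta_D},
\qquad
\theta_G \leftarrow \theta_G - \eta \, \frac{\partial f(\theta_G, \theta_D)}{\partial \theta_G}
```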

In the diagram below, we add the gradient descent (SGD) formula to show explicitly how the discriminator parameters are updated.

In Unrolled GAN, we train the discriminator exactly the same way as GAN.

Generator training

Unrolled GAN plays k steps ahead to learn how the discriminator may optimize itself against the current generator. In practice, 5 to 10 unrolled steps give fairly good model performance. The diagram below unrolls the process 3 times.

At each unrolled step, the cost function is computed from the latest discriminator parameters, while the generator parameters remain unchanged.

At each step, we apply gradient descent to compute a new set of discriminator parameters.
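In equations (our transcription of the paper's formulation, with i indexing the unrolled steps and η the learning rate), the discriminator copy starts from the current parameters and takes k steps, each an ascent step on f (equivalently, a descent step on the discriminator's loss -f). The generator's cost f_k is the original f evaluated at the final copy. With k = 1, θ_D^1 is exactly the single update applied to the real discriminator.

```latex
\theta_D^{0} = \theta_D, \qquad
\theta_D^{i+1} = \theta_D^{i} + \eta \, \frac{\partial f(\theta_G, \theta_D^{i})}{\partial \theta_D^{i}},
\quad i = 0, \dots, k-1

f_k(\theta_G, \theta_D) = f\left(\theta_G, \theta_D^{k}(\theta_G, \theta_D)\right)
```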

But as mentioned before, we only use the first step to update the discriminator. The unrolling is used by the generator to predict the discriminator's moves, not for the discriminator's own optimization. Otherwise, we may overfit the discriminator to a specific generator.

For the generator, we backpropagate the gradient through all k steps. This is very similar to how an LSTM is unrolled and its gradients are backpropagated through time. Because every unrolled discriminator update depends on the generator parameters, the generator's gradient picks up a contribution from each of the k steps, as shown above.
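The chain rule makes this concrete (θ_D^k denotes the discriminator parameters after k unrolled steps starting from θ_D). The first term is the ordinary generator gradient, evaluated at the unrolled discriminator. The second term exists only because each unrolled step depends on θ_G, and computing dθ_D^k/dθ_G is what requires backpropagating through all k steps. With k = 0, θ_D^0 = θ_D does not depend on θ_G, the second term vanishes, and we recover the standard GAN generator update.

```latex
\frac{d f_k(\theta_G, \theta_D)}{d \theta_G}
  = \frac{\partial f(\theta_G, \theta_D^{k})}{\partial \theta_G}
  + \frac{\partial f(\theta_G, \theta_D^{k})}{\partial \theta_D^{k}}
    \, \frac{d \theta_D^{k}(\theta_G, \theta_D)}{d \theta_G}
```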

To summarize, Unrolled GAN backpropagates the cost calculated at the last unrolled step to update the generator, while the discriminator is updated using the first step only.
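To see both halves of this summary in numbers, here is a minimal sketch on a hypothetical scalar game (not from the paper or its code). The generator is a single number g, the discriminator a single number d, and f(g, d) = g*d - d**2/2, which D ascends and G descends. Instead of autograd, it carries the derivative of the unrolled d with respect to g forward through the k steps, which yields the same gradient that backpropagating through the unrolled steps would. The function names are our own.

```python
# Hypothetical toy game: f(g, d) = g*d - 0.5*d**2.
# D ascends f (its best response is d = g); G descends f.

def f(g: float, d: float) -> float:
    return g * d - 0.5 * d * d

def df_dg(g: float, d: float) -> float:
    return d

def df_dd(g: float, d: float) -> float:
    return g - d

def unrolled_value_and_grad(g: float, d0: float, k: int, lr_d: float):
    """Return f_k(g, d0) and its total derivative with respect to g.

    A copy of the discriminator takes k gradient-ascent steps from d0.
    dd_dg tracks how the unrolled d depends on g (forward-mode chain rule).
    """
    d = d0
    dd_dg = 0.0  # d0 itself does not depend on g
    for _ in range(k):
        # d_{i+1} = d_i + lr_d * (g - d_i), so d(d_{i+1})/dg = dd_dg + lr_d * (1 - dd_dg)
        dd_dg = dd_dg + lr_d * (1.0 - dd_dg)
        d = d + lr_d * df_dd(g, d)
    value = f(g, d)
    # Direct term plus the term that flows back through the unrolled steps
    grad = df_dg(g, d) + df_dd(g, d) * dd_dg
    return value, grad

def train_step(g: float, d: float, k: int, lr_d: float, lr_g: float):
    """One Unrolled GAN iteration on the toy game."""
    # Generator: gradient descent on the k-step surrogate f_k
    _, grad_g = unrolled_value_and_grad(g, d, k, lr_d)
    # Discriminator: a single ordinary ascent step; the unrolled copies are discarded
    new_d = d + lr_d * df_dd(g, d)
    return g - lr_g * grad_g, new_d
```

For example, `unrolled_value_and_grad(1.0, 0.0, 3, 0.5)` returns `(0.4921875, 0.984375)`. With k = 0 the generator's gradient is 0.0, because the current discriminator d = 0 gives it no signal; the unrolled version anticipates that D will move toward g and pushes g toward the equilibrium at 0. As k grows, the gradient approaches g, which is the gradient of max_d f(g, d) = g**2/2.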

Coding

An implementation of Unrolled GAN can be found here. It is actually pretty simple. The core logic for unrolling k steps is:

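cur_update_dict = update_dict
for i in range(params['unrolling_steps'] - 1):
    # Substitute the previous step's updated variables into the one-step update
    cur_update_dict = graph_replace(update_dict, cur_update_dict)
# The unrolled loss uses the discriminator parameters after the last step
unrolled_loss = graph_replace(loss, cur_update_dict)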

In this loop, graph_replace rewires the graph so that each unrolled step, and finally the loss, uses the discriminator parameters produced by the previous step instead of the current ones. The core logic for building the computation graph in TensorFlow is as follows.

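# Build the generator and discriminator; reuse=True shares the discriminator
# weights between the real and fake branches
with slim.arg_scope([slim.fully_connected],
                    weights_initializer=tf.orthogonal_initializer(gain=1.4)):
    samples = generator(noise, output_dim=params['x_dim'])
    real_score = discriminator(data)
    fake_score = discriminator(samples, reuse=True)

# Discriminator loss: classify real data as 1 and generated samples as 0
loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=real_score,
                                            labels=tf.ones_like(real_score)) +
    tf.nn.sigmoid_cross_entropy_with_logits(logits=fake_score,
                                            labels=tf.zeros_like(fake_score)))

...
# One optimizer step for the discriminator; only this step is applied to D
updates = d_opt.get_updates(disc_vars, [], loss)
d_train_op = tf.group(*updates, name="d_train_op")
...
# Get dictionary mapping from variables to their update value
# after one optimization step
update_dict = extract_update_dict(updates)
cur_update_dict = update_dict
for i in range(params['unrolling_steps'] - 1):
    # Compute the next step's updates from the previous step's updated variables
    cur_update_dict = graph_replace(update_dict, cur_update_dict)
# The unrolled loss is evaluated with the discriminator parameters after the last step
unrolled_loss = graph_replace(loss, cur_update_dict)
...
# The generator maximizes the discriminator loss (hence the minus sign);
# its gradient flows back through all the unrolled discriminator updates
g_train_op = g_train_opt.minimize(-unrolled_loss, var_list=gen_vars)
...
f, _, _ = sess.run([[loss, unrolled_loss], g_train_op, d_train_op])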
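graph_replace rewrites TensorFlow graphs, so it cannot be shown standalone here. As a rough plain-Python analogy (our own illustration, not how the library works internally), treat update_dict as a function from discriminator parameters to their values after one step. Substituting the previous step's outputs into it is then function composition, and the loop runs unrolling_steps - 1 times because update_dict already represents one step. The names make_update and unroll are hypothetical, and the sketch shows only the substitution pattern, not the differentiation that TensorFlow then performs through the composed graph.

```python
from typing import Callable, Dict

Params = Dict[str, float]
Update = Callable[[Params], Params]

def make_update(grad_fn: Callable[[Params], Params], lr: float) -> Update:
    """One gradient-ascent step on the discriminator parameters (like update_dict)."""
    def update(params: Params) -> Params:
        grads = grad_fn(params)
        return {name: value + lr * grads[name] for name, value in params.items()}
    return update

def unroll(update: Update, steps: int) -> Update:
    """Compose `update` with itself `steps` times, mirroring the graph_replace loop."""
    cur = update
    for _ in range(steps - 1):
        prev = cur
        # Like graph_replace(update_dict, cur_update_dict): feed the previous
        # step's output into the one-step update (prev=prev avoids late binding)
        cur = lambda params, prev=prev: update(prev(params))
    return cur
```

For instance, with `update = make_update(lambda p: {"w": 1.0 - p["w"]}, 0.5)`, `unroll(update, 3)({"w": 0.0})` gives `{"w": 0.875}`; the analogue of `unrolled_loss` would be the loss evaluated on that result.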

Experiments

In the experiment below, we start with a toy dataset containing a mixture of 8 Gaussian distributions. With a less complex generator, the regular GAN in the second row generates good-quality samples but fails to achieve diversity: the mode collapses. With Unrolled GAN (the first row), the model discovers all 8 modes with high quality.
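The toy setup can be sketched as follows, assuming the 8 Gaussians are spaced evenly on a circle, the usual layout for this benchmark. The radius 2.0, the standard deviation 0.02, and the rule "a mode counts as covered if some sample lands within 3 standard deviations of its center" are our own hypothetical choices, not values taken from the experiment above. A mode-collapsed generator would score well below 8.

```python
import math
import random
from typing import List, Tuple

Point = Tuple[float, float]

def ring_centers(n_modes: int = 8, radius: float = 2.0) -> List[Point]:
    """Centers of n_modes Gaussians spaced evenly on a circle (hypothetical radius)."""
    return [(radius * math.cos(2 * math.pi * i / n_modes),
             radius * math.sin(2 * math.pi * i / n_modes))
            for i in range(n_modes)]

def sample_mixture(n: int, rng: random.Random, n_modes: int = 8,
                   radius: float = 2.0, std: float = 0.02) -> List[Point]:
    """Draw n points: pick a mode uniformly, then add isotropic Gaussian noise."""
    centers = ring_centers(n_modes, radius)
    points = []
    for _ in range(n):
        cx, cy = rng.choice(centers)
        points.append((rng.gauss(cx, std), rng.gauss(cy, std)))
    return points

def modes_covered(samples: List[Point], n_modes: int = 8, radius: float = 2.0,
                  std: float = 0.02) -> int:
    """Count modes that receive at least one sample within 3 std of their center."""
    centers = ring_centers(n_modes, radius)
    hit = [False] * n_modes
    for x, y in samples:
        dists = [math.hypot(x - cx, y - cy) for cx, cy in centers]
        nearest = min(range(n_modes), key=lambda i: dists[i])
        if dists[nearest] <= 3 * std:
            hit[nearest] = True
    return sum(hit)
```

Samples drawn from the real mixture cover all 8 modes, while a collapsed generator that always emits points near one center covers only 1.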

Source

RNN generators are particularly vulnerable to mode collapse. The Unrolled GAN (the first row) manages to discover all 10 modes, while a regular GAN model collapses (the second row).

Source

Further readings

If you want to learn more about improving GANs:

A full listing of all articles in this series:

Reference

Unrolled GAN paper

Code implementation in TensorFlow
