Weidi Xu

Many machine learning models involve discrete variables. For instance, the latent variables of generative models are sometimes discrete, and the hard version of the attention mechanism treats the attention at each step as a discrete stochastic variable. In these cases, backpropagation is not directly applicable, which makes such stochastic networks difficult to train. This blog summarizes several methods for dealing with the gradient propagation problem, from the straight-through estimator (2013) to Gumbel-softmax (2016).

Problem definition

Consider a simple stochastic model with a discrete random variable \(x\) whose probability is given by \(p_\theta(x)\), and a loss function \(f(x)\). The objective of training is to minimize the expected loss \(L(\theta)=\mathbb{E}_{p_\theta(x)}[f(x)]\). There are roughly three kinds of methods for dealing with this problem; they are introduced below.
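To make the setup concrete, here is a minimal PyTorch sketch (the names `theta` and `f_values` are illustrative, not from the original): a categorical \(p_\theta(x)\) parameterized by logits, a per-category loss \(f(x)\), and a Monte Carlo estimate of \(L(\theta)\). The sampling step breaks the gradient path to \(\theta\), which is exactly the problem the estimators below address.

```python
import torch

theta = torch.zeros(4, requires_grad=True)       # logits of a 4-way categorical p_theta(x)
f_values = torch.tensor([0.1, 0.5, 0.3, 0.9])    # hypothetical loss f(x) for each category

def expected_loss_mc(theta, num_samples=1000):
    probs = torch.softmax(theta, dim=-1)
    # Discrete samples: non-differentiable, so no gradient flows back to theta.
    x = torch.multinomial(probs, num_samples, replacement=True)
    return f_values[x].mean()                    # Monte Carlo estimate of L(theta)

print(expected_loss_mc(theta))
```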

Straight-through estimator

This method was proposed by Bengio, 2013. The idea behind the straight-through estimator is to backpropagate through the thresholding function as if it were the identity function. When the stochastic variable is binary with \(P(s=1)=p_\theta\), the sample is drawn as \(s=1_{z<p_\theta}\) with \(z\sim U[0,1]\), and the gradient is estimated as \(\frac{\partial f}{\partial s}\,\nabla_\theta p_\theta\), i.e., the thresholding step is treated as the identity during the backward pass. This estimator is biased but has low variance. For more information, I direct you here.
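As a rough illustration (a single binary unit and PyTorch's `detach`; not code from the original post), the straight-through trick uses the hard sample in the forward pass while letting the gradient flow through the probability:

```python
import torch

logit = torch.tensor(0.3, requires_grad=True)
p = torch.sigmoid(logit)              # p_theta = P(s = 1)

z = torch.rand(())                    # z ~ U[0, 1]
s_hard = (z < p).float()              # hard 0/1 sample, non-differentiable
s = p + (s_hard - p).detach()         # forward value is s_hard; backward treats s as p

loss = (s - 1.0) ** 2                 # stand-in for f(s)
loss.backward()
print(s.item(), logit.grad)           # grad = df/ds at the hard sample, times dp/dlogit
```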

Likelihood-ratio estimator

The likelihood-ratio estimator (also known as the score-function estimator or REINFORCE) plays an important part in dealing with discrete variables. The objective function implies that the distribution of the latent variable \(x\) can be regarded as a policy in a reinforcement learning problem, with the loss function corresponding to the (negative) reward signal. Hence policy gradient methods, especially the REINFORCE algorithm, are a popular choice. The estimator is given as:

\[\nabla_\theta L(\theta) = \mathbb{E}_{p_\theta(x)}\left[f(x)\, \nabla_\theta \log p_\theta(x)\right].\]

However, this method is well known to suffer from high variance, so many variance-reduction techniques (e.g., control variates/baselines) have been put forward.
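Below is a hedged sketch of the score-function estimator with a simple running-average baseline as one such variance-reduction technique; the toy distribution, loss values, and learning rate are illustrative assumptions, not from the original post.

```python
import torch

theta = torch.zeros(4, requires_grad=True)        # logits of a 4-way categorical
f_values = torch.tensor([0.1, 0.5, 0.3, 0.9])     # hypothetical loss f(x) per category
baseline = 0.0

for step in range(100):
    probs = torch.softmax(theta, dim=-1)
    x = torch.multinomial(probs, 64, replacement=True)   # batch of discrete samples
    f_x = f_values[x]                                     # f(x) for each sample
    log_p = torch.log(probs[x])                           # log p_theta(x)
    # Surrogate whose gradient is E[(f(x) - b) * grad log p_theta(x)]
    surrogate = ((f_x - baseline).detach() * log_p).mean()
    theta.grad = None
    surrogate.backward()
    with torch.no_grad():
        theta -= 0.1 * theta.grad                         # simple SGD step
    baseline = 0.9 * baseline + 0.1 * f_x.mean().item()   # running-average baseline
```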

Gumbel-softmax

The Gumbel-softmax was recently used in Jang, 2016 and Maddison, 2016 to propagate gradients through discrete variables. It is a ‘reparameterization trick’ for the categorical distribution; more precisely, it is a reparameterization trick for a continuous relaxation that can be smoothly annealed into the categorical distribution via a temperature parameter. Refer to the Gumbel-softmax tutorial.
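As a rough sketch of the idea (the helper name `gumbel_softmax_sample` and the temperature value are illustrative), one samples Gumbel noise, adds it to the logits, and applies a temperature-controlled softmax to obtain a differentiable relaxation of a one-hot sample:

```python
import torch

def gumbel_softmax_sample(logits, tau=1.0):
    u = torch.rand_like(logits)
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)     # Gumbel(0, 1) noise
    return torch.softmax((logits + gumbel) / tau, dim=-1)  # soft, differentiable sample

theta = torch.zeros(4, requires_grad=True)                 # logits of a 4-way categorical
f_values = torch.tensor([0.1, 0.5, 0.3, 0.9])              # hypothetical loss f(x) per category

y = gumbel_softmax_sample(theta, tau=0.5)                  # approaches one-hot as tau -> 0
loss = (y * f_values).sum()                                # relaxed f, differentiable in theta
loss.backward()
print(y, theta.grad)
```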