From Deep Mixtures to Deep Quantiles - Part 3 - 2019-05-18

In which we struggle to come up with acronyms, and end up with DQQRNs

Least squares regression is taught early on in every science course. The mean squared error (MSE) arises naturally from minimizing the negative log-likelihood under an assumed Gaussian distribution of outcomes - a special case of the Gaussian mixture density with $m=1$ components (and $\sigma=\textrm{const}$). Less frequently, however, engineers, statisticians, and other like-minded individuals need regression models less sensitive to outliers - and replace the MSE with the mean absolute error (MAE). Statistically speaking, minimizing the MSE allows us to learn the (conditional) mean of our data, while minimizing the MAE yields the median. The median estimate is empirical and does not rely on assumptions about the underlying distribution. The median is a special quantile - can we generalize our MAE loss function to learn other quantiles in a similar way?

In this post, we introduce the Quantile Loss and use Deep Quantile Regression to get an alternative view on the uncertainty in the target variable - and also propose some (to my knowledge) new approaches as an unfinished experiment for the curious reader.

Quantiles - a probabilistic perspective

Modelling uncertainty within a machine learning task requires a suitable representation of the underlying probability distribution. In Part 1 of this series of posts, we modelled the probability density function (p.d.f.) of the target distribution using a sum of one or more Gaussians, thus obtaining a learnable parametric representation. This Gaussian Mixture model allows practical sampling by using standard algorithms for efficiently generating normal and discrete random variables.

Another common approach is to approximate the cumulative distribution function (c.d.f.), e.g. as a piecewise linear function. In fact, knowing the inverse of the c.d.f. - a.k.a. the quantile function - is sufficient for very efficient inverse transform sampling.
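As a quick illustration (a minimal numpy sketch of my own, not code from the references), once the quantile function is known on a grid of $q$-values, inverse transform sampling reduces to interpolation:

```python
import numpy as np

# Hypothetical example: the quantile function Q(q) known at a few levels.
q_grid = np.array([0.05, 0.25, 0.50, 0.75, 0.95])    # quantile levels
y_quantiles = np.array([-1.6, -0.7, 0.0, 0.7, 1.6])  # Q(q) at those levels

# Inverse transform sampling: evaluate Q at uniform random numbers.
u = np.random.uniform(q_grid[0], q_grid[-1], size=10_000)
samples = np.interp(u, q_grid, y_quantiles)           # piecewise-linear Q(q)
```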

The quantile loss

Given data points $y_i$, we can compute the median $y_{q=0.5}$ as the value which minimizes the MAE loss function, i.e. $y_{q=0.5} = \mathrm{argmin}_{\hat y} L_{q=0.5}(\hat y)$, with the loss function (summed over all points)

$$L_{q=0.5}(\hat y) = \sum_i l_{q=0.5, i}(\hat y) = \sum_i |\hat y - y_i|$$

It turns out that we can compute quantiles for any $q$ with a slightly modified loss function, known as the quantile, pinball or tilted loss function - a “tilted” version of the V-shaped MAE loss:

$$ l_{q, i}(\hat y) = q \times |\hat y - y_i| \quad \text{ for } \hat y - y_i \leq 0 $$ $$ l_{q, i}(\hat y) = (1-q) \times |\hat y - y_i| \quad \text{ for } \hat y - y_i \geq 0 $$

The loss looks something like this:

Quantile loss (a.k.a. tilted loss)

Intuitively, for $q > 0.5$, points with values $y_i$ below $\hat y$ (right-hand side of the plot) contribute proportionally less (with weight $1-q$), and the optimal value of $\hat y$ partitions the data in the $q : (1-q)$ ratio, i.e. a fraction $q$ of the points lies below it.
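For concreteness, the tilted loss is easy to write as a custom Keras/TensorFlow loss (a minimal sketch; the function name is mine):

```python
import tensorflow as tf

def quantile_loss(q):
    """Pinball / tilted loss for a fixed quantile level q in (0, 1)."""
    def loss(y_true, y_pred):
        e = y_true - y_pred                       # positive when we under-predict
        # q * e when e >= 0, (q - 1) * e when e < 0, i.e. (1 - q) * |e|
        return tf.reduce_mean(tf.maximum(q * e, (q - 1.0) * e))
    return loss
```

Setting $q = 0.5$ recovers (half of) the MAE, so the median really is just a special case.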

Non-linear quantile regression

Now that we have established how to find quantiles through minimizing a differentiable loss function, the world of TensorFlow is our oyster. First, we expect our predictions to be conditional on some features $x$, and our data is of the form $(x_i, y_i)$.

We define a parametric prediction function $f(x)$, and minimize the quantile loss until we learn to predict the chosen quantile. The function $f$ can be any of the commonly used models: a linear function, an ensemble of decision trees, a different ensemble of decision trees, or a feed-forward neural network.
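As a concrete example (a minimal Keras sketch; `make_quantile_model` and the layer sizes are my own choices), a feed-forward $f(x)$ for a single quantile could look like this, reusing the `quantile_loss` helper from the sketch above:

```python
from tensorflow import keras

def make_quantile_model(q, n_features=1):
    """A small feed-forward net trained to predict one conditional quantile."""
    model = keras.Sequential([
        keras.Input(shape=(n_features,)),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(1),                    # the predicted quantile of y given x
    ])
    model.compile(optimizer="adam", loss=quantile_loss(q))
    return model

# One independent model per quantile, e.g.:
# models = {q: make_quantile_model(q) for q in (0.1, 0.25, 0.5, 0.75, 0.9)}
# for q, m in models.items():
#     m.fit(x_train, y_train, epochs=200, verbose=0)
```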

In fact, the top row of images in this post shows the result of learning 7 different conditional quantiles (including the median) with each of the latter three algorithms, on the “pathological” S-shaped data set.

Deep quantile regression with layer sharing

As usual with deep learning, we can trade off prediction quality and overfitting by choosing from a wide space of hyperparameters, not least model architectures and layer counts. Above, we trained a separate model for each quantile - this can seem wasteful given that we expect the problems of predicting different quantiles to be closely related. In addition, we can see that learning quantiles independently can lead to the unintended consequence of quantiles “crossing over” in an unrealistic way.

Luckily, we can impose a “smoothness” prior by training a single model to predict several quantiles simultaneously.
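One way to set this up (a sketch under my own naming; the post's actual implementation is linked in the references) is a shared trunk with one output unit per quantile and a summed pinball loss:

```python
import tensorflow as tf
from tensorflow import keras

quantiles = [0.05, 0.25, 0.50, 0.75, 0.95]          # example set of quantile levels

def multi_quantile_loss(y_true, y_pred):
    """Sum of pinball losses, one per output column / quantile level."""
    y = tf.reshape(y_true, (-1,))                   # accept (N,) or (N, 1) targets
    losses = []
    for i, q in enumerate(quantiles):
        e = y - y_pred[:, i]
        losses.append(tf.reduce_mean(tf.maximum(q * e, (q - 1.0) * e)))
    return tf.add_n(losses)

inputs = keras.Input(shape=(1,))
h = keras.layers.Dense(64, activation="relu")(inputs)   # shared layers
h = keras.layers.Dense(64, activation="relu")(h)
outputs = keras.layers.Dense(len(quantiles))(h)          # one unit per quantile
model = keras.Model(inputs, outputs)
model.compile(optimizer="adam", loss=multi_quantile_loss)
# model.fit(x_train, y_train, epochs=200)
```

Because all quantile heads share the same hidden layers, their predictions tend to vary smoothly together, which in practice reduces (though does not strictly prevent) crossing.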

This approach showed some promise in experiments with synthetic data, and, more importantly, all of the deep quantile regression networks (DQRNs?) converged more reliably than mixture density networks (MDNs) - some example plots can be seen below.

Regression over all quantiles

However, one key problem remains - we still have to choose a fixed, finite number of quantiles before starting the training process. To generate samples from the learned distribution using inverse transform sampling, we still need to perform piecewise linear interpolation for intermediate $q$-values. What if instead we could train a model to predict the location of the quantile given $x$ and $q$? Here’s what I came up with - I call it the “Inverse CDF Regression” or the “Deep Quantile Q Regression Network”:

Remember that the quantile loss function above depends on $q$. To make it work, we need two tricks:

  • the “loss model trick”: just like we did with MDNs, we define a secondary “loss” model which has inputs $(x, q, y)$, which leaves us with the freedom to define arbitrary loss functions
  • the “data augmentation trick”: we generate batches of $q$ randomly at training time, allowing the model to learn a reasonable approximation of the inverse cumulative distribution function - a sketch of both tricks follows below.
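Here is roughly how the two tricks fit together in Keras (a sketch with my own names and layer sizes; see reference 3 for the actual implementation). The $q$-dependent pinball loss is attached through a small custom layer, which plays the role of the “loss model”:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

class PinballLossLayer(keras.layers.Layer):
    """'Loss model trick': attach the q-dependent pinball loss via add_loss."""
    def call(self, inputs):
        y_true, q, y_pred = inputs
        e = y_true - y_pred
        self.add_loss(tf.reduce_mean(tf.maximum(q * e, (q - 1.0) * e)))
        return y_pred

x_in = keras.Input(shape=(1,), name="x")
q_in = keras.Input(shape=(1,), name="q")
y_in = keras.Input(shape=(1,), name="y")

# Prediction model: an approximation of the quantile function Q(q | x).
h = keras.layers.Concatenate()([x_in, q_in])
h = keras.layers.Dense(64, activation="relu")(h)
h = keras.layers.Dense(64, activation="relu")(h)
y_hat = keras.layers.Dense(1)(h)
prediction_model = keras.Model([x_in, q_in], y_hat)

# Loss model: the same network, with y as an extra input so the loss can use it.
out = PinballLossLayer()([y_in, q_in, y_hat])
loss_model = keras.Model([x_in, q_in, y_in], out)
loss_model.compile(optimizer="adam")     # the loss comes entirely from add_loss

# 'Data augmentation trick': draw a fresh random q for every training example.
# q_batch = np.random.uniform(0.01, 0.99, size=(len(x_train), 1))
# loss_model.fit([x_train, q_batch, y_train], epochs=...)
```

After training, only `prediction_model` is needed: it maps $(x, q)$ directly to the estimated quantile.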

Some results

Training this model is generally quite easy, but training it well can require additional hyper-parameter tuning. Combining two very different features ($x$ and $q$) can lead to mismatched gradient magnitudes and other convergence problems, so a few tweaks helped along the way:

  • including batch normalization
  • over-sampling certain regions of $q$ values, especially when learning distributions with very long tails
  • choosing suitable activation functions in the final layer - a simple ReLU can struggle to output a wide range of values for long-tailed distributions, and something like $\exp$ or $\sinh$ works better (see the small sketch below) - but this feels like cheating and too much hand-tuning.
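For example (a hypothetical snippet of my own, not the post's code), the output head could be a linear unit followed by $\sinh$, so that a modest pre-activation range maps onto a wide, signed output range:

```python
import tensorflow as tf
from tensorflow import keras

# Hypothetical long-tail-friendly output head: Dense followed by sinh.
head = keras.Sequential([
    keras.layers.Dense(1),
    keras.layers.Lambda(tf.math.sinh),   # expands the dynamic range of the output
])
```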

However, with suitable parameters, convergence is much more robust than that of Mixture Density Networks (MDNs). In addition, inverse sampling works very well indeed - all it takes is a batch of uniform random numbers ($q$ values) and a single feed-forward pass.
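Concretely (assuming the `prediction_model` from the earlier DQQRN sketch, with inputs $(x, q)$), conditional sampling at some fixed $x$ could look like:

```python
import numpy as np

n_samples = 1000
q_samples = np.random.uniform(0.01, 0.99, size=(n_samples, 1))  # uniform q values
x_fixed = np.full((n_samples, 1), 2.5)                           # condition on a chosen x
y_samples = prediction_model.predict([x_fixed, q_samples])       # one forward pass
```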

Sampling from a DQQRN

Here’s another example of a “difficult” distribution:

and samples from the learned distributions using both MDNs and different variants of the DQRNs:

Training samples

While the DQQRN produces a decent “smooth” approximation, it struggles with the long tail.

Conclusion

This series of posts arose from working on several practical ML problems that required a better understanding of the uncertainty in model predictions. Some of the approaches presented here are still work in progress, but could be applicable to a wide range of problems.

So first of all, when should we bother with the complexity of MDNs and/or Quantile Regressions? Some or all of the following are probably good indicators:

  • we are dealing with a regression problem
  • the “human baseline” is poor or non-existent
  • the prediction uncertainty is dominated by the stochastic nature of the target variable, rather than model inaccuracy
  • the target distributions are heteroscedastic - i.e. vary depending on input features
  • the target distributions are non-Gaussian
  • we need to generate samples for modelling “what if” scenarios

How do the different approaches stack up against one another? The table below summarizes my limited experience:

|                               | Mixture Density Network (MDN) | Quantile Regression                  |
|-------------------------------|-------------------------------|--------------------------------------|
| standard implementations      | try mine?                     | sklearn/catboost/various stats tools |
| generalizes to $n$ dimensions | ✔️                            |                                      |
| convergence/ease of training  | tricky ❌                      | depends; v. good for simple cases ✔️ |
| fast sampling                 | ✔️                            | ✔️                                   |

I am hoping that, mathematical curiosity aside, these approaches will warrant some practical interest and, more importantly, that regression modelling “beyond MSE” will become more commonplace.

References

  1. You may find the related late (or early?) Christmas post interesting

  2. Deep quantile regression for specified quantiles, including sampling methods: a keras implementation

  3. Deep Quantile Regression over $q$ (learning the quantile function): a keras implementation

  4. Notebooks and experiments for this post
