From Deep Mixtures to Deep Quantiles - Part 1 - 2019-02-16

In which we learn everything about $y$ and (ab)use Keras to optimize anything

What is the error of your latest deep learning regression model? Well, since you had a well-defined objective function - say the MSE - you already know the answer. But you are asking yourself (or, more likely, your boss is asking you): can we do better?

The answer depends on whether the error is due to model errors, a.k.a. accuracy, or to intrinsic randomness in the target variable, a.k.a. precision.

And if the unpredictable randomness dominates the error, there is hardly anything we can do to improve on it. Or can we? What if, instead of learning to predict a single value, we could capture the probability distribution, i.e. everything there is to know about the target variable?

This post assumes some familiarity with maximum likelihood estimation and deep learning. I might get round to introducing some foundational concepts in another post.

The trouble with data

Consider the dataset of $n$ pairs $X_i$-$y_i$ below - your prediction $\hat y$ can be very accurate throughout, but can never be precise around $X=0.5$, where the error is dominated by the underlying variable’s random distribution.

To make matters worse, the likelihood $p(y|X)$, a.k.a. the conditional probability density function (p.d.f.), is far from the usual Gaussian - and for some $X$, $y$ is effectively multi-valued. And yet, the Gaussian assumption underpins the concepts of least squares and standard deviation applied in most regression problems.

To summarize, in the general case our conditional p.d.f. suffers from being non-Gaussian, sometimes badly so. It can be

  • multi-modal - the most likely values cluster around several means
  • heteroscedastic - the spread of the distribution changes with $X$
  • skewed - the distribution is not symmetric, making e.g. values far above the mean more likely than values far below it.

Yet another universal approximator

The go-to distribution more complex than a Gaussian is - you guessed it - a (weighted) mixture of $m$ Gaussians. We can write

$$ p(y|X) = c \sum_{j=1}^{m} w_j(X) \sigma_j(X)^{-1} \exp \left\{ -\frac{\left[y - \mu_j(X)\right]^2}{2 \sigma_j^2(X)} \right\} $$ with normalization constant $c = 1/\sqrt{2\pi}$ and constraints $\sum_j w_j = 1$, $w_j \ge 0$ and $\sigma_j > 0$.

For a given $X$, essentially any continuous density can be approximated arbitrarily well by a mixture with enough components - akin to a basis function decomposition. This addresses the multi-modality and skew issues. And by choosing appropriate functional forms for the weights $w_j(X)$, means $\mu_j(X)$ and spreads $\sigma_j(X)$, we should be well on our way to dealing with heteroscedasticity (phew!).
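To make the mixture density concrete, here is a minimal NumPy sketch that evaluates a Gaussian mixture for a single, fixed $X$ and checks numerically that it integrates to one. The function name `mixture_pdf` and the parameter values are made up for illustration; they are not taken from the dataset in this post.

```python
import numpy as np

def mixture_pdf(y, w, mu, sigma):
    """Density of a 1-D Gaussian mixture evaluated at points y.

    y: array of shape (k,); w, mu, sigma: arrays of shape (m,).
    """
    y = np.asarray(y, dtype=float)[:, None]        # (k, 1), broadcasts against (m,)
    norm = 1.0 / (np.sqrt(2.0 * np.pi) * sigma)    # per-component normalisation
    phi = norm * np.exp(-0.5 * ((y - mu) / sigma) ** 2)
    return phi @ w                                 # weighted sum over components

# Hypothetical parameters for one value of X: two well-separated modes
w = np.array([0.3, 0.7])
mu = np.array([-1.0, 2.0])
sigma = np.array([0.2, 0.5])

grid = np.linspace(-6.0, 8.0, 20001)
density = mixture_pdf(grid, w, mu, sigma)
# Riemann sum over the grid; should be close to 1 for a valid density
area = float(np.sum(density) * (grid[1] - grid[0]))
```

With just two components the density already has two distinct peaks of different widths, which a single Gaussian cannot reproduce.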

As our go-to functional is a deep neural network, we start by implementing one that can predict $w_j$, $\mu_j$ and $\sigma_j$ simultaneously:

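# Imports assume tf.keras; with standalone Keras use keras.layers / keras.models
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras.models import Model

m = 3  # number of mixture components

input_features = Input((1, ), name='X')
intermediate = Dense(32, activation='relu')(input_features)  # shared hidden layer
# maybe we need more layers...
weight = Dense(m, activation='softmax')(intermediate)  # weights sum to 1
mu = Dense(m, activation=None)(intermediate)  # unconstrained means
sigma = Dense(m, activation='relu')(intermediate)  # non-negative spreads

mdn_model = Model(input_features, [weight, mu, sigma])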

The key points to note are:

  • we are making use of “shared layers” before branching off different outputs
  • we choose activations in the output layers depending on constraints we want to impose. For example, the “weight” outputs will always add up to 1.

This model, introduced by Bishop, is known as the Mixture Density Network (MDN):

[Figure: Mixture Density Network]

Beyond loss(y_true, y_pred)

Our optimization objective is to maximize the likelihood of the training data. As is usual, we will instead minimize the sum of negative-log-likelihoods $ \mathcal{L} = - \sum_{i=1}^n \ln p \left( y_i|X_i \right) $ over all examples. Keeping just the general form, we need to minimize

$$ \mathcal{L} = \sum_{i=1}^n \mathcal{L}\left( y_i, \mathbf{w}(X_i), \mathbf{\mu}(X_i), \mathbf{\sigma}(X_i) \right) $$
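To make the per-example term concrete, here is a plain NumPy sketch of the negative log-likelihood of a Gaussian mixture, using the standard normalised Gaussian density. The function name `mdn_nll` and all parameter values are hypothetical, and the sum over components is done in log space for numerical stability, which is a choice made for this sketch.

```python
import numpy as np

def mdn_nll(y, w, mu, sigma):
    """Per-example negative log-likelihood under a Gaussian mixture.

    y: shape (n,); w, mu, sigma: shape (n, m), one row of parameters per example.
    """
    y = np.asarray(y, dtype=float)[:, None]
    # Log-density of each Gaussian component
    log_phi = (-0.5 * np.log(2.0 * np.pi) - np.log(sigma)
               - 0.5 * ((y - mu) / sigma) ** 2)
    # log of sum_j w_j * phi_j, accumulated stably in log space
    log_p = np.logaddexp.reduce(np.log(w) + log_phi, axis=1)
    return -log_p

# Hypothetical parameters for n = 2 examples and m = 2 components
y = np.array([0.0, 1.5])
w = np.array([[0.5, 0.5], [0.2, 0.8]])
mu = np.array([[0.0, 0.0], [-1.0, 1.5]])
sigma = np.array([[1.0, 1.0], [0.5, 0.3]])

losses = mdn_nll(y, w, mu, sigma)  # one loss per example
total = losses.sum()               # the quantity to minimize
```

Note that the second example has a negative loss: a density can exceed 1 when a component is narrow, so the negative log-likelihood is not bounded below by zero.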

Keras is very good at optimizing loss functions of the form $ \sum_{i=1}^n \mathcal{L}\left( y_i, \hat y(X_i) \right)$, but that form does not fit our network out of the box, because our loss depends on three outputs rather than a single $\hat y$. While anything can be implemented with Keras and TensorFlow, we have grown increasingly lazy (thanks, Python and Keras!).

So here’s a lazy way: we will need

  • a loss layer, which combines our multiple outputs with the labels $y_i$ into a single loss output $l_i$ per example
  • a dummy loss function which simply returns the mean of predicted loss outputs
  • and finally, a loss model - a Keras model which takes $X_i$ and $y_i$ as inputs, and returns losses $l_i$ - with a suitable optimizer

# Imports assume tf.keras; with standalone Keras use keras.backend / keras.layers
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Lambda

# Compute the tensor of losses (one negative log-likelihood per example)
def mdn_loss_tensor(y_true, w, mu, sigma):
    inv_sigma = 1 / (sigma + K.epsilon())
    # Gaussian density of each component; the constant 1/sqrt(2*pi) is dropped,
    # as it only shifts the loss
    phi = inv_sigma * K.exp(-0.5 * K.square((y_true - mu) * inv_sigma))
    return -K.log(K.sum(w * phi, axis=1) + K.epsilon())

# Loss layer with 4 inputs and 1 output
MDNLossLayer = Lambda(lambda args: mdn_loss_tensor(*args))

# Dummy loss with 2 inputs; y_true is ignored
def dummy_loss(y_true, y_pred):
    return K.mean(y_pred)

# Now we can define the model:
input_labels = Input((1, ), name='y')
loss_output = MDNLossLayer([input_labels, weight, mu, sigma])

loss_model = Model([input_features, input_labels], loss_output)  # 2 inputs, 1 output
loss_model.compile(optimizer='Adam', loss=dummy_loss)

Note the use of K.epsilon() to ensure that “dangerous” operations, here the division and the logarithm, never return nan.

We have defined a multi-input, single-output keras model that shares layers with the mdn_model defined above:

[Figure: Mixture Density Loss Network]

Note that we pass the labels $y$ to one of the inputs of the loss model. Instead of explicitly comparing labels $y$ with a model output $\hat y$, we use $y$ to compute a more complex optimization objective. Any training we perform using loss_model.fit() will simultaneously optimize the weights in mdn_model:

loss_model.fit([X_train, y_train], 0 * y_train)

Our “loss model” takes a list of two arrays for the two inputs, and a dummy output array that is ignored by dummy_loss - it just needs to have the correct number of rows. Once the model is trained, we can inspect the predictions:

w_test, mu_test, sigma_test = mdn_model.predict(X_test)

Near the top of this post, I have plotted the predicted means of three components, shaded by their respective mixture weights $w$. The lowest negative log-likelihood is generally achieved in the situation shown in the first plot. However, the model is quite prone to getting stuck in local minima, some of which completely discard one of the components - as shown by the plots obtained by re-training the model repeatedly.

Looking at the mean $\langle y \rangle = \sum_j w_j \mu_j$ plotted above, we can see that, while accurate, it conveys little about our data! The real strength of the MDN lies in capturing the full conditional p.d.f. - allowing us to compute arbitrary summary statistics of the data.
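As an illustration of such summary statistics, the sketch below computes the conditional mean and standard deviation of a Gaussian mixture from its parameters, using $\langle y \rangle = \sum_j w_j \mu_j$ and $\mathrm{Var}(y) = \sum_j w_j (\sigma_j^2 + \mu_j^2) - \langle y \rangle^2$. The helper name `mixture_mean_std` is hypothetical, and the arrays merely mimic the shapes returned by mdn_model.predict; their values are made up.

```python
import numpy as np

def mixture_mean_std(w, mu, sigma):
    """Conditional mean and standard deviation of a Gaussian mixture.

    w, mu, sigma: arrays of shape (n, m), one row of parameters per example.
    """
    mean = np.sum(w * mu, axis=1)
    second_moment = np.sum(w * (sigma ** 2 + mu ** 2), axis=1)
    var = second_moment - mean ** 2
    # Clip tiny negative values caused by floating point rounding
    return mean, np.sqrt(np.maximum(var, 0.0))

# Hypothetical predictions for one test point with two symmetric, narrow modes
w_test = np.array([[0.5, 0.5]])
mu_test = np.array([[-1.0, 1.0]])
sigma_test = np.array([[0.1, 0.1]])

mean, std = mixture_mean_std(w_test, mu_test, sigma_test)
```

In this toy case the mean falls exactly between the two modes, where the density is close to zero, while the standard deviation is ten times larger than either component's spread. The mean alone would be a poor description of such data.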

Implementation

The complete code is available as part of the deepquantiles package. It includes a convenient MixtureDensityRegressor class, all the losses, and an example Jupyter notebook.

What’s next

A trained Mixture Density Network can be used to generate samples from the learned distribution. I have included an implementation in the deepquantiles package - efficient sampling from a heteroscedastic mixture model certainly warrants another post!
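In the meantime, here is a rough NumPy sketch of one way to sample from row-wise mixtures: pick a component per row according to its weights, then draw from that component's Gaussian. This is not the implementation from the deepquantiles package; the function name `sample_mixture` and the parameter values are assumptions made for illustration.

```python
import numpy as np

def sample_mixture(w, mu, sigma, rng):
    """Draw one sample per row from row-wise Gaussian mixtures.

    w, mu, sigma: arrays of shape (n, m); rows of w must sum to 1.
    """
    n, m = w.shape
    # Pick a component index per row by inverting the cumulative weights
    cdf = np.cumsum(w, axis=1)
    u = rng.random((n, 1))
    idx = np.minimum((u > cdf).sum(axis=1), m - 1)  # guard against rounding in cdf
    rows = np.arange(n)
    # Draw from the selected component of each row
    return rng.normal(mu[rows, idx], sigma[rows, idx])

# Hypothetical parameters: the same two-component mixture repeated n times
rng = np.random.default_rng(0)
n = 20000
w = np.tile([0.25, 0.75], (n, 1))
mu = np.tile([-2.0, 3.0], (n, 1))
sigma = np.tile([0.1, 0.1], (n, 1))

samples = sample_mixture(w, mu, sigma, rng)
```

Because the component choice is vectorised over rows, each row can carry its own weights, means and spreads, which is what a heteroscedastic model needs.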

Mixture models are certainly a popular approach to capturing a complex conditional probability distribution, but not the only possible one. In future parts of this series, I will describe an alternative method using Deep Quantile Regression.

References

  1. Original paper: Mixture Density Networks, Christopher M. Bishop, NCRG/94/004 (1994)

  2. A probabilistic approach combining Tensorflow with Edward: http://cbonnett.github.io/MDN_EDWARD_KERAS_TF.html

  3. There are a number of alternative implementations in “low-level” Tensorflow, the more interesting ones making use of the relatively recent tensorflow_probability - worth checking out!

  4. Slides from my talk at the TensorFlow London meetup about Smart Energy and MDNs: Beyond MSE: Forecasting probability distributions with Tensorflow
