From Deep Mixtures to Deep Quantiles - Part 2 - 2019-02-24

In which we (attempt to) speed up sampling from a mixture density model

In Part 1 of this series of posts, we trained a Mixture Density Network to capture a heteroscedastic conditional probability distribution.

After predicting the parameters $(w_{ij}, \mu_{ij}, \sigma_{ij})$ of the $m$ mixture components, we would like to generate samples (in some cases, many samples) from the learned distribution.
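For reference, here is the conditional density these parameters define at input $X_i$. It assumes Gaussian components, as in the rest of this post, and is a weighted sum of normal densities whose weights for each row sum to one:

```latex
p(y \mid X_i) = \sum_{j=1}^{m} w_{ij} \, \mathcal{N}\!\left(y;\, \mu_{ij}, \sigma_{ij}^{2}\right),
\qquad w_{ij} \ge 0, \quad \sum_{j=1}^{m} w_{ij} = 1
```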

A fully vectorized implementation for the problem at hand required some thought, so I decided to write it up for future reference.

Being able to sample is important for two reasons:

  • first, it gives us access to unlimited “synthetic” data,
  • and second, any statistic you need can be computed from a big enough sample.
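The sketch below illustrates the second point. It estimates a mean and a few quantiles of a single two-component Gaussian mixture from samples, then checks the mean against its closed form $\sum_j w_j \mu_j$. The parameter values are made up for illustration. The sketch uses NumPy's Generator API (`np.random.default_rng`) and the two-step sampling procedure described below.

```python
import numpy as np

# Illustrative parameters of one Gaussian mixture (not taken from a trained model)
w = np.array([0.3, 0.7])       # mixture weights, sum to 1
mu = np.array([-1.0, 2.0])     # component means
sigma = np.array([0.5, 1.0])   # component standard deviations

rng = np.random.default_rng(0)
n = 200_000

# Two-step sampling: pick a component for each draw, then sample that Gaussian
choice = rng.choice(len(w), size=n, p=w)
samples = mu[choice] + sigma[choice] * rng.standard_normal(n)

# Any statistic can be estimated from a large enough sample...
est_mean = samples.mean()
est_quantiles = np.quantile(samples, [0.05, 0.5, 0.95])

# ...and, where a closed form exists, checked against it: E[y] = sum_j w_j * mu_j
exact_mean = np.dot(w, mu)  # 0.3 * (-1.0) + 0.7 * 2.0 = 1.1
print(est_mean, exact_mean, est_quantiles)
```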

Imagine we want to sample num_samples random values from a mixture of num_components Gaussians, at each of num_rows values of $X_i$.

While numpy.random provides functions for efficient sampling from a number of probability distributions, we are trying to sample from num_rows different distributions simultaneously.
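The difficulty is specific to the categorical step. NumPy's normal sampler broadcasts array-valued `loc` and `scale`, so drawing from many Gaussians at once is easy. `choice`, by contrast, accepts only a one-dimensional `p`, so it cannot pick a component for many rows in a single call. A minimal sketch, assuming the Generator API and randomly generated parameters with the same shapes as the test data below:

```python
import numpy as np

rng = np.random.default_rng(0)
num_rows, num_components = 100, 5

# Random parameters with the shapes used in this post
w = rng.random((num_rows, num_components))
w /= w.sum(axis=1, keepdims=True)
mu = rng.random((num_rows, num_components))
sigma = rng.random((num_rows, num_components))

# The Gaussian step vectorizes directly: loc and scale broadcast elementwise,
# giving one draw from each of the num_rows * num_components Gaussians
gaussian_draws = rng.normal(loc=mu, scale=sigma)
print(gaussian_draws.shape)  # (100, 5)

# The categorical step does not: choice() rejects a 2-D probability array
try:
    rng.choice(num_components, p=w)
    choice_accepts_2d_p = True
except ValueError:
    choice_accepts_2d_p = False
print(choice_accepts_2d_p)  # False
```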

Let us create some test data of appropriate shape (the complete code is linked in the references below):

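import numpy as np

def create_data(num_rows=100, num_components=5):
    # Mixture weights: softmax over random logits, so each row sums to 1
    w = np.random.rand(num_rows, num_components)
    w = np.exp(w) / np.exp(w).sum(axis=1, keepdims=True)
    # Component means and standard deviations, uniform on [0, 1)
    mu = np.random.rand(num_rows, num_components)
    sigma = np.random.rand(num_rows, num_components)
    return w, mu, sigma

w, mu, sigma = create_data()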

Sampling from a single Gaussian Mixture (at row i) is a two-step process:

  1. Choose a random integer, choice, between 0 and num_components - 1, with probabilities w[i, :]
  2. Draw a random number from the chosen Gaussian component, i.e. mu[i, choice] + sigma[i, choice] * z, where z is standard normal

This can be implemented as follows:

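# Implementation 0 - two nested loops
num_rows, num_components = w.shape  # parameters from create_data() above
num_samples = 1000

samples = np.empty((num_rows, num_samples))
for i in range(num_rows):
    _w, _mu, _sigma = w[i, :], mu[i, :], sigma[i, :]
    for j in range(num_samples):
        # Step 1: pick a mixture component with probabilities w[i, :]
        choice = np.random.choice(num_components, p=_w)
        # Step 2: draw from the chosen Gaussian component
        samples[i, j] = _mu[choice] + _sigma[choice] * np.random.randn()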

Looping over numpy arrays in Python is rarely a good idea. Both np.random.choice and np.random.randn can generate an array of random values in a single call. Combined with numpy's fancy indexing (indexing _mu with the integer array choice), this lets us get rid of the inner for loop:

# Implementation 1 - a single loop over rows
samples = np.empty((num_rows, num_samples))
for i in range(num_rows):
    _w, _mu, _sigma = w[i, :], mu[i, :], sigma[i, :]
    # Draw all num_samples component indices and normal deviates for row i at once
    choice = np.random.choice(num_components, size=num_samples, p=_w)
    samples[i] = _mu[choice] + _sigma[choice] * np.random.randn(num_samples)

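For reasonable sizes (num_rows = 1000, num_samples = 1000, num_components = 5), the speed gain is massive. The table also lists the fully vectorized Implementation 2, described below:

Implementation                         Sampling time
Implementation 0 (two nested loops)    21 s
Implementation 1 (single loop)         0.13 s
Implementation 2 (no python loops)     0.12 s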
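Timings like these depend on hardware and NumPy version. The sketch below shows one way to reproduce such a comparison. It assumes Implementation 1 is wrapped in a function; the name `sample_loop_rows` is hypothetical. `timeit.repeat` with `number=1` times single runs, and the minimum over three repeats is reported.

```python
import timeit
import numpy as np

def sample_loop_rows(w, mu, sigma, num_samples):
    """Implementation 1 wrapped in a function (hypothetical name)."""
    num_rows, num_components = w.shape
    samples = np.empty((num_rows, num_samples))
    for i in range(num_rows):
        choice = np.random.choice(num_components, size=num_samples, p=w[i])
        samples[i] = mu[i, choice] + sigma[i, choice] * np.random.randn(num_samples)
    return samples

# Same sizes as in the table above
num_rows, num_samples, num_components = 1000, 1000, 5
w = np.random.rand(num_rows, num_components)
w = np.exp(w) / np.exp(w).sum(axis=1, keepdims=True)
mu = np.random.rand(num_rows, num_components)
sigma = np.random.rand(num_rows, num_components)

# Best of three single runs, in seconds; results vary between machines
t1 = min(timeit.repeat(lambda: sample_loop_rows(w, mu, sigma, num_samples),
                       number=1, repeat=3))
print(f"Implementation 1: {t1:.2f}s")
```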

However, getting rid of the remaining outer loop over rows is harder. The only way I could come up with is as follows:

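# Implementation 2 - no python loops
## First, sample which mixture component to use for each (row, sample)
choices = np.random.rand(num_rows, 1, num_samples)  # one uniform draw per (row, sample)
thresholds = w.cumsum(axis=1)[:, :, np.newaxis]     # cumulative mixture weights per row
# For each draw, pick the first component whose cumulative weight exceeds it
col_idx = (choices < thresholds).argmax(axis=1)
# Row index of every (row, sample) pair, shape (num_rows, num_samples)
_, row_idx = np.meshgrid(np.arange(num_samples), np.arange(num_rows))

## Second, compute the samples
# New names, so that the (num_rows, num_components) parameters are not overwritten
mu_chosen = mu[row_idx, col_idx]
sigma_chosen = sigma[row_idx, col_idx]
samples = mu_chosen + sigma_chosen * np.random.randn(num_rows, num_samples)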
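This indexing is easy to get subtly wrong, so it is worth checking the vectorized sampler statistically. The sketch below wraps Implementation 2 in a function (the name `sample_vectorized` is hypothetical). It compares per-row sample means with the exact mixture means $\sum_j w_{ij}\mu_{ij}$ on three hand-picked rows.

```python
import numpy as np

def sample_vectorized(w, mu, sigma, num_samples):
    """Implementation 2 wrapped in a function (hypothetical name)."""
    num_rows = w.shape[0]
    choices = np.random.rand(num_rows, 1, num_samples)
    thresholds = w.cumsum(axis=1)[:, :, np.newaxis]
    col_idx = (choices < thresholds).argmax(axis=1)
    _, row_idx = np.meshgrid(np.arange(num_samples), np.arange(num_rows))
    noise = np.random.randn(num_rows, num_samples)
    return mu[row_idx, col_idx] + sigma[row_idx, col_idx] * noise

np.random.seed(0)
# Three deliberately different mixtures, one per row (illustrative values)
w = np.array([[0.9, 0.1, 0.0],
              [0.2, 0.3, 0.5],
              [1 / 3, 1 / 3, 1 / 3]])
mu = np.array([[0.0, 10.0, 20.0],
               [-5.0, 0.0, 5.0],
               [1.0, 2.0, 3.0]])
sigma = np.full((3, 3), 0.1)

samples = sample_vectorized(w, mu, sigma, 100_000)
exact_means = (w * mu).sum(axis=1)  # [1.0, 1.5, 2.0]
print(samples.mean(axis=1), exact_means)
```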

The approach hinges on numpy array broadcasting. Note the shapes of the arrays we construct to choose a mixture component for each sample:

Array                     Shape
choices                   (num_rows, 1, num_samples)
thresholds                (num_rows, num_components, 1)
(choices < thresholds)    (num_rows, num_components, num_samples)
col_idx                   (num_rows, num_samples)
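A tiny worked example shows both the broadcast shapes and the lookup itself. It replaces np.random.rand with fixed "uniform draws" so that the result is predictable, and the numbers are illustrative. Each draw u selects the first component whose cumulative weight exceeds u. This is inverse transform sampling for a discrete distribution (see the reference on sampling discrete choices).

```python
import numpy as np

# 2 rows, 3 components, 4 samples per row (illustrative numbers)
w = np.array([[0.2, 0.5, 0.3],
              [0.6, 0.3, 0.1]])
# Fixed draws in place of np.random.rand(2, 1, 4)
choices = np.array([[0.10, 0.65, 0.75, 0.99],
                    [0.10, 0.65, 0.75, 0.99]])[:, np.newaxis, :]  # (2, 1, 4)
thresholds = w.cumsum(axis=1)[:, :, np.newaxis]                  # (2, 3, 1)

below = choices < thresholds    # broadcasts to (2, 3, 4)
col_idx = below.argmax(axis=1)  # index of the first True along the components

print(choices.shape, thresholds.shape, below.shape, col_idx.shape)
# (2, 1, 4) (2, 3, 1) (2, 3, 4) (2, 4)
print(col_idx.tolist())  # [[0, 1, 2, 2], [0, 1, 1, 2]]
```

In the first row the cumulative weights are 0.2, 0.7 and 1.0. The draw 0.65 therefore lands in component 1, and 0.75 lands in component 2.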

This way, we can index the right mu and sigma for all samples simultaneously. All this indexing does not come for free, however, and the speed difference relative to the looped Implementation 1 is less clear-cut:

[Figure: Vectorizing MDN sampling, a log-log contour plot comparing the looped and vectorized implementations]

When sampling a large number of points for each row, the looped approach is in fact marginally faster! However, vectorization comes into its own for a large number of rows. There, the speed of my final implementation depends only on the total number of points sampled (i.e. num_rows * num_samples), as evidenced by the straight-line contours in the log-log plot above.

Can we make it faster still? I suspect so. Is the speed-up significant in practice? It depends. Did solving this problem make me revisit vectorization and broadcasting? Definitely.
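One direction that might help is sketched here as a hypothetical variant and has not been benchmarked. It replaces the argmax with a count of cumulative weights at or below each draw, and it gathers the chosen parameters with np.take_along_axis instead of building a row-index grid with np.meshgrid. Whether it is actually faster would need measuring. The function name and the use of the Generator API are assumptions, not part of the original code.

```python
import numpy as np

def sample_vectorized_v2(w, mu, sigma, num_samples, rng=None):
    """Hypothetical variant of Implementation 2 (not benchmarked).

    Because each row's cumulative weights are non-decreasing, the number of
    thresholds <= u equals the index of the first threshold > u, which is
    exactly what argmax(u < thresholds) returns in Implementation 2.
    """
    rng = np.random.default_rng() if rng is None else rng
    num_rows, num_components = w.shape
    u = rng.random((num_rows, 1, num_samples))
    thresholds = w.cumsum(axis=1)[:, :, np.newaxis]
    col_idx = (u >= thresholds).sum(axis=1)  # shape (num_rows, num_samples)
    # Clip in case cumsum rounds the last threshold slightly below 1
    col_idx = np.minimum(col_idx, num_components - 1)
    # Gather mu[i, col_idx[i, k]] without an explicit row-index array
    mu_chosen = np.take_along_axis(mu, col_idx, axis=1)
    sigma_chosen = np.take_along_axis(sigma, col_idx, axis=1)
    return mu_chosen + sigma_chosen * rng.standard_normal((num_rows, num_samples))

# Usage with random parameters of the shapes used in this post
rng = np.random.default_rng(0)
w = rng.random((100, 5))
w /= w.sum(axis=1, keepdims=True)
mu = rng.random((100, 5))
sigma = rng.random((100, 5))
samples = sample_vectorized_v2(w, mu, sigma, num_samples=1000, rng=rng)
print(samples.shape)  # (100, 1000)
```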

References

  1. Array broadcasting in numpy: https://docs.scipy.org/doc/numpy/user/theory.broadcasting.html
  2. Implicit array expansion in MATLAB: https://uk.mathworks.com/help/matlab/ref/bsxfun.html
  3. My code for this post: https://github.com/ig248/deepquantiles/blob/master/notebooks/mdn_sampling.ipynb
  4. My implementation of the MixtureDensityRegressor, including sampling methods: https://github.com/ig248/deepquantiles/blob/master/deepquantiles/regressors/mdn.py
  5. Box–Muller transform for generating normal random variables: https://en.wikipedia.org/wiki/Box%E2%80%93Muller_transform
  6. Marsaglia polar method for generating normal random variables (as used in numpy.random): https://en.wikipedia.org/wiki/Marsaglia_polar_method
  7. Sampling discrete choices: https://en.wikipedia.org/wiki/Pseudo-random_number_sampling#Finite_discrete_distributions