Is there a way to parallelize scipy.integrate.quad over a set of values?


I am computing the posterior probability of some parameters θ with a likelihood function that is slow to compute, and I wonder if there is a way to speed things up.

The slow part of my code is the chi-squared computation, which requires an integral to be evaluated for every redshift in z_sn.
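For reference, these are the quantities the code computes. Here omgM is written as Ω_M, and inv_cov_plus is taken to be the inverse covariance C^{-1} of the observed magnitudes m_obs. That reading is an assumption based on the variable name.

```latex
E(z) = \sqrt{(1-\Omega_M) + \Omega_M (1+z)^3}

d_L(z) = \frac{c}{100\,h\ \mathrm{km\,s^{-1}\,Mpc^{-1}}}\,(1+z)\int_0^{z} \frac{dz'}{E(z')}

\mu(z) = 5\log_{10}\!\left(\frac{d_L(z)}{\mathrm{Mpc}}\right) + 25

\chi^2(h, \Omega_M) = \Delta^{\mathsf{T}} C^{-1} \Delta, \qquad \Delta_i = m_{\mathrm{obs},i} - \mu(z_i)
```

Each χ² evaluation therefore needs one separate integral for every entry of z_sn.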

import numpy as np
import scipy.constants as cte
from scipy.integrate import quad

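def luminosity_integrand(z, omgM):
    # 1 / E(z) for flat LCDM with matter density omgM
    Ez = np.sqrt((1 - omgM) + omgM * np.power(1 + z, 3))
    return 1. / Ez

def luminosity_distance(z, h, omgM):
    # one adaptive quad call per redshift; args should be a tuple
    integral, _ = quad(luminosity_integrand, 0, z, epsrel=1e-8, args=(omgM,))
    # cte.c is in m/s, so cte.c / 1e5 numerically equals c / (100 km/s/Mpc) in Mpc
    return (cte.c / 10. ** 5) / h * (1 + z) * integral

def distance_modulus(z, h, omgM):
    # luminosity distance in Mpc -> distance modulus
    return 5. * np.log10(luminosity_distance(z, h, omgM)) + 25.

def chisq_sn(h, omgM):
    # z_sn, m_obs and inv_cov_plus are global data arrays
    m_model = np.array([distance_modulus(z, h, omgM) for z in z_sn])
    diffs = m_obs - m_model

    maha_distances = np.dot(np.dot(diffs, inv_cov_plus), diffs)  # squared Mahalanobis distance (chi^2)
    return maha_distances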

I have used a list comprehension, but I don't know whether it's the fastest approach. I have also seen suggestions to use Cython, but I have never used it. I am using emcee to sample the posterior, if that matters.
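One possible alternative uses only public NumPy/SciPy API. Every integral has the same smooth integrand and the same lower limit 0, so a fixed-order Gauss-Legendre rule can evaluate all of them in a single vectorized NumPy expression instead of one quad call per redshift. This is a swapped-in technique, not something from the original post. It is a sketch under assumptions: the 64-node order is a guess you should validate against quad on your own z_sn and omgM range, and luminosity_integrals, distance_modulus_vec and chisq_sn_vectorized are hypothetical names. The data arrays are passed as arguments instead of globals.

```python
import numpy as np
import scipy.constants as cte
from scipy.integrate import quad

def luminosity_integrand(z, omgM):
    # 1 / E(z) for flat LCDM; works elementwise on arrays
    return 1.0 / np.sqrt((1 - omgM) + omgM * (1 + z) ** 3)

# Gauss-Legendre nodes and weights on [-1, 1]; 64 nodes is an assumed order
_NODES, _WEIGHTS = np.polynomial.legendre.leggauss(64)

def luminosity_integrals(z, omgM):
    """Approximate int_0^z dz'/E(z') for every z at once (hypothetical helper)."""
    z = np.asarray(z, dtype=float)
    # map nodes from [-1, 1] to [0, z]: z' = z (x + 1) / 2, dz' = (z / 2) dx
    zp = z[..., None] * (_NODES + 1) / 2
    return z / 2 * (luminosity_integrand(zp, omgM) @ _WEIGHTS)

def distance_modulus_vec(z, h, omgM):
    # same formula as distance_modulus, applied to an array of redshifts
    d_l = (cte.c / 10.0 ** 5) / h * (1 + z) * luminosity_integrals(z, omgM)
    return 5.0 * np.log10(d_l) + 25.0

def chisq_sn_vectorized(h, omgM, z_sn, m_obs, inv_cov_plus):
    # squared Mahalanobis distance, with the data passed explicitly
    diffs = m_obs - distance_modulus_vec(z_sn, h, omgM)
    return diffs @ inv_cov_plus @ diffs

# usage: check against the quad loop (omgM = 0.3 is an arbitrary test value)
z_check = np.linspace(0.01, 2.3, 500)
ref = np.array([quad(luminosity_integrand, 0, z, args=(0.3,), epsabs=0, epsrel=1e-10)[0]
                for z in z_check])
print("max relative error:", np.max(np.abs(luminosity_integrals(z_check, 0.3) / ref - 1)))
```

Because the rule is fixed, there is no adaptive error control. It is worth running the comparison against quad once over the parameter range emcee will explore.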

Answer by Matt Haberland

If you don't mind using a private function (it is not officially supported and may not be available in future versions of SciPy), SciPy 1.12.0 has a vectorized integrator, _tanhsinh, that is roughly an order of magnitude faster than the quad loop (146 ms vs. 10.4 ms in the timings below).

import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import quad
from scipy.integrate._tanhsinh import _tanhsinh  # private API, present in SciPy 1.12.0

def luminosity_integrand(z, omgM):
    Ez = np.sqrt((1 - omgM) + omgM * np.power(1 + z, 3))
    return 1. / Ez

z_sn = np.linspace(0, 10, 1000)
omgM = 0.5

# %timeit is IPython/Jupyter magic; use the timeit module in a plain script
%timeit [quad(luminosity_integrand, 0, z, args=(omgM,))[0] for z in z_sn]
# 146 ms ± 23.3 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

%timeit _tanhsinh(luminosity_integrand, 0, z_sn, args=(omgM,)).integral
# 10.4 ms ± 1.35 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)

luminosity1 = [quad(luminosity_integrand, 0, z, args=(omgM,))[0] for z in z_sn]
luminosity2 = _tanhsinh(luminosity_integrand, 0, z_sn, args=(omgM,)).integral

np.testing.assert_allclose(luminosity1, luminosity2)

plt.plot(z_sn, luminosity1, '-', label='quad')
plt.plot(z_sn, luminosity2, '--', label='tanh-sinh')
plt.xlabel('z_sn')
plt.ylabel('luminosity')
plt.legend()

[Figure: luminosity vs. z_sn for the quad and tanh-sinh results]
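Because _tanhsinh is private, one hedge is to import it defensively. The sketch below tries a public import first, then the private location used above, and finally falls back to the quad loop. The public name scipy.integrate.tanhsinh is an assumption about newer SciPy releases (I believe it became public in SciPy 1.15), so check the documentation for your installed version. integrate_many is a hypothetical helper name.

```python
import warnings

import numpy as np
from scipy.integrate import quad

try:
    # assumed public location in newer SciPy releases
    from scipy.integrate import tanhsinh as _vec_integrator
except ImportError:
    try:
        # private location used in the answer (SciPy 1.12.0)
        from scipy.integrate._tanhsinh import _tanhsinh as _vec_integrator
    except ImportError:
        _vec_integrator = None  # older SciPy: no vectorized integrator available

def luminosity_integrand(z, omgM):
    return 1.0 / np.sqrt((1 - omgM) + omgM * np.power(1 + z, 3))

def integrate_many(f, a, b, args=()):
    """Integrate f from a to every upper limit in b (hypothetical helper)."""
    b = np.asarray(b, dtype=float)
    if _vec_integrator is not None:
        res = _vec_integrator(f, a, b, args=args)
        if not np.all(res.success):
            warnings.warn("tanh-sinh did not converge for some upper limits")
        return res.integral
    # fallback: one adaptive quad call per upper limit, as in the question
    return np.array([quad(f, a, bi, args=args)[0] for bi in b.ravel()]).reshape(b.shape)

# usage with the same integrand as above (omgM = 0.5)
z_sn = np.linspace(0.01, 10, 1000)
luminosity = integrate_many(luminosity_integrand, 0, z_sn, args=(0.5,))
```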