How to extract the distance and transport matrices from SciPy's wasserstein_distance?

The scipy.stats.wasserstein_distance function returns only the minimum distance (the solution) between two input distributions, p and q. That distance is the sum of the element-wise product of a distance matrix and an optimal transport matrix, both of which I assume are computed inside the same function.

How can I extract the distance matrix and optimal transport matrix that correspond to the solution as 2nd and 3rd output arguments?

There is 1 best solution below

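As far as I can tell, you cannot get the transport matrix out of scipy's wasserstein_distance. The function handles 1-D distributions and works from their cumulative distribution functions, so it never builds a transport matrix internally. You can get one from other packages, though, such as pyemd (https://github.com/wmayner/pyemd). I have been using this package for a while; it works well and runs quickly. Look at the function emd_with_flow() in the Usage section of its README, which returns the flow (transport) matrix along with the distance.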

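As for the distance matrix, it is an input to the EMD calculation, not an output: you build it yourself from the pairwise distances between the bins (support points) of the two distributions.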
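
To make the "input vs. output" point concrete, here is a minimal pure-Python sketch for the 1-D case. It is not how SciPy or pyemd do it internally. It assumes the ground distance is |x - y| and relies on the fact that, in one dimension, filling the transport matrix greedily in sorted order (the monotone or "north-west corner" coupling) gives an optimal plan. The name transport_plan_1d is made up for this example; the argument order just mirrors wasserstein_distance(u_values, v_values, u_weights, v_weights).

```python
def transport_plan_1d(u_values, v_values, u_weights, v_weights):
    """Return (distance, distance_matrix, transport_matrix) for two 1-D
    discrete distributions. Hypothetical helper, not part of SciPy or pyemd."""
    # Normalise the weights so each distribution has total mass 1.
    u_total, v_total = sum(u_weights), sum(v_weights)
    p = [w / u_total for w in u_weights]
    q = [w / v_total for w in v_weights]

    # The distance matrix is an INPUT: pairwise |x - y| between support points.
    D = [[abs(x - y) for y in v_values] for x in u_values]

    # Visit the support points of each distribution in increasing order.
    rows = sorted(range(len(u_values)), key=lambda i: u_values[i])
    cols = sorted(range(len(v_values)), key=lambda j: v_values[j])

    # The transport matrix is the OUTPUT: T[i][j] is the mass moved from
    # u_values[i] to v_values[j].
    T = [[0.0] * len(v_values) for _ in u_values]
    supply, demand = p[:], q[:]
    a = b = 0
    eps = 1e-12  # tolerance for floating-point leftovers
    while a < len(rows) and b < len(cols):
        i, j = rows[a], cols[b]
        moved = min(supply[i], demand[j])
        T[i][j] += moved
        supply[i] -= moved
        demand[j] -= moved
        if supply[i] <= eps:
            a += 1  # source point i is exhausted
        if demand[j] <= eps:
            b += 1  # target point j is satisfied

    # The distance is the sum of the element-wise product of T and D.
    distance = sum(T[i][j] * D[i][j]
                   for i in range(len(p)) for j in range(len(q)))
    return distance, D, T


distance, D, T = transport_plan_1d([0, 1], [0, 1], [3, 1], [2, 2])
print(distance)  # 0.25
print(D)         # [[0, 1], [1, 0]]
print(T)         # [[0.5, 0.25], [0.0, 0.25]]
```

The 0.25 here should agree with scipy.stats.wasserstein_distance([0, 1], [0, 1], [3, 1], [2, 2]). For distributions in more than one dimension, or for other ground distances, this greedy fill is no longer guaranteed to be optimal, and a solver such as pyemd's emd_with_flow() is the better route.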