How to calculate mutual information in PyTorch (differentiable estimator)


I am training a model with PyTorch and, as part of my loss function, I need to measure the degree of dependence between two tensors (say, two tensors whose values are each very close to zero or one, e.g. v1 = [0.999, 0.998, 0.001, 0.98] and v2 = [0.97, 0.01, 0.997, 0.999]). I am trying to calculate mutual information, but I can't find any mutual information estimator implemented in PyTorch. Has such a thing been provided anywhere?



Answer by Umang Gupta

Mutual information is defined for distributions, not for individual points. So I will write the rest of this answer assuming that v1 and v2 are samples from a distribution p, and that you have n > 1 samples from p.
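For reference, for discrete random variables X and Y with joint distribution p(x, y) and marginals p(x) and p(y), mutual information (in nats when the natural log is used) is the quantity below. Sample-based estimators differ mainly in how they approximate these distributions, or the ratio between them, from the n samples.

```latex
I(X;Y) = \sum_{x}\sum_{y} p(x,y)\,\log\frac{p(x,y)}{p(x)\,p(y)}
```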

You want a method to estimate mutual information from samples, and there are many ways to do this. One of the simplest is a non-parametric estimator such as NPEET (https://github.com/gregversteeg/NPEET). It works with NumPy, so you would convert your tensors from torch to NumPy first; note that this conversion detaches them from the autograd graph, so the resulting estimate cannot be backpropagated through as part of a loss. There are also more involved parametric estimators for which you may be able to find PyTorch implementations (see https://arxiv.org/abs/1905.06922).
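As an illustration of the parametric route, here is a minimal sketch of one widely used variational lower bound on MI, InfoNCE. The choice of bound, the separable critic `SeparableCritic`, its layer sizes, and the toy data are all assumptions for illustration, not necessarily what the answer had in mind. The critic scores every (x_i, y_j) pair in a batch; paired samples sit on the diagonal.

```python
import math
import torch
import torch.nn as nn


class SeparableCritic(nn.Module):
    """Hypothetical critic f(x, y) = g(x) . h(y) that scores all pairs in a batch."""

    def __init__(self, x_dim: int, y_dim: int, hidden: int = 32, embed: int = 16):
        super().__init__()
        self.g = nn.Sequential(nn.Linear(x_dim, hidden), nn.ReLU(), nn.Linear(hidden, embed))
        self.h = nn.Sequential(nn.Linear(y_dim, hidden), nn.ReLU(), nn.Linear(hidden, embed))

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # scores[i, j] = f(x_i, y_j); the diagonal holds the paired samples
        return self.g(x) @ self.h(y).T


def infonce_lower_bound(scores: torch.Tensor) -> torch.Tensor:
    """InfoNCE bound in nats: mean_i [f(x_i, y_i) - log((1/n) * sum_j exp f(x_i, y_j))]."""
    n = scores.shape[0]
    return (scores.diag() - torch.logsumexp(scores, dim=1)).mean() + math.log(n)


# Toy data: y is a noisy copy of x, so the two are strongly dependent
torch.manual_seed(0)
x = torch.randn(128, 1)
y = x + 0.1 * torch.randn(128, 1)

critic = SeparableCritic(1, 1)
opt = torch.optim.Adam(critic.parameters(), lr=1e-3)
for _ in range(200):
    mi_est = infonce_lower_bound(critic(x, y))
    opt.zero_grad()
    (-mi_est).backward()  # maximise the bound by training the critic
    opt.step()
```

A known property of this bound is that it can never exceed log(batch size), so it saturates for strongly dependent variables unless the batch is large.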

If you only have two vectors and want a similarity measure, a dot-product similarity is more suitable than mutual information, since there is no distribution to estimate.
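A minimal sketch of that alternative on the two vectors from the question: both the raw dot product and the (length-normalised) cosine similarity are ordinary differentiable tensor operations, so either can go straight into a loss.

```python
import torch
import torch.nn.functional as F

v1 = torch.tensor([0.999, 0.998, 0.001, 0.98])
v2 = torch.tensor([0.97, 0.01, 0.997, 0.999])

dot = torch.dot(v1, v2)                      # unnormalised similarity
cos = F.cosine_similarity(v1, v2, dim=0)     # similarity in [-1, 1]
```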

Answer by cut936

It is not provided in the official PyTorch code, but here is a PyTorch implementation that uses kernel density estimation to approximate the histogram. Note that this method is fully differentiable.
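The implementation this answer refers to is not reproduced here. As a rough illustration of the same idea, the sketch below soft-bins each sample with a Gaussian kernel, builds a soft joint histogram, and plugs it into the discrete MI formula. The helper name `kde_mutual_info`, the bin count `num_bins`, the bandwidth `sigma`, and the assumption that values lie in [0, 1] are all hypothetical choices.

```python
import torch


def kde_mutual_info(x: torch.Tensor, y: torch.Tensor, num_bins: int = 16,
                    sigma: float = 0.05, eps: float = 1e-10) -> torch.Tensor:
    """Hypothetical differentiable MI estimate (nats) for paired 1-D samples in [0, 1]."""
    centers = torch.linspace(0.0, 1.0, num_bins, dtype=x.dtype, device=x.device)
    # Gaussian kernel weight of every sample w.r.t. every bin centre: shape (n, num_bins)
    kx = torch.exp(-0.5 * ((x.unsqueeze(1) - centers) / sigma) ** 2)
    ky = torch.exp(-0.5 * ((y.unsqueeze(1) - centers) / sigma) ** 2)
    # Normalise per sample so that each sample contributes a total mass of 1
    kx = kx / (kx.sum(dim=1, keepdim=True) + eps)
    ky = ky / (ky.sum(dim=1, keepdim=True) + eps)
    joint = kx.T @ ky / x.numel()          # soft joint histogram, sums to 1
    px = joint.sum(dim=1, keepdim=True)    # marginal of x, shape (num_bins, 1)
    py = joint.sum(dim=0, keepdim=True)    # marginal of y, shape (1, num_bins)
    return (joint * (torch.log(joint + eps) - torch.log(px @ py + eps))).sum()


torch.manual_seed(0)
x = torch.rand(1000, requires_grad=True)
y_dep = (x.detach() + 0.02 * torch.randn(1000)).clamp(0.0, 1.0)  # strongly dependent on x
y_ind = torch.rand(1000)                                          # independent of x
mi_dep = kde_mutual_info(x, y_dep)
mi_ind = kde_mutual_info(x, y_ind)
mi_dep.backward()  # gradients flow back to x through the soft histogram
```

The bandwidth trades bias for smoothness: a small `sigma` approaches a hard histogram (sharper but with poorly behaved gradients), while a large one blurs real dependence away.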

Alternatively, you can use the differentiable histogram functions in Kornia to compute the MI yourself if you want more control.
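Computing MI yourself can be quite simple when, as in the question, every value is close to 0 or 1. The sketch below uses plain PyTorch rather than Kornia; the helper name `soft_binary_mutual_info` is hypothetical, and it assumes each entry is the probability that a binary variable equals 1 and that positions in v1 and v2 are paired samples. It builds a 2×2 soft contingency table and applies the discrete MI formula.

```python
import torch


def soft_binary_mutual_info(p: torch.Tensor, q: torch.Tensor, eps: float = 1e-10) -> torch.Tensor:
    """Hypothetical MI (nats) between two binary variables with per-sample P(bit = 1) = p, q."""
    # Each row is a soft one-hot encoding [P(bit = 0), P(bit = 1)]
    a = torch.stack([1 - p, p], dim=1)   # shape (n, 2)
    b = torch.stack([1 - q, q], dim=1)   # shape (n, 2)
    joint = a.T @ b / p.numel()          # 2x2 soft contingency table, sums to 1
    pa = joint.sum(dim=1, keepdim=True)  # marginal of the first variable
    pb = joint.sum(dim=0, keepdim=True)  # marginal of the second variable
    return (joint * (torch.log(joint + eps) - torch.log(pa @ pb + eps))).sum()


v1 = torch.tensor([0.999, 0.998, 0.001, 0.98], requires_grad=True)
v2 = torch.tensor([0.97, 0.01, 0.997, 0.999])
mi = soft_binary_mutual_info(v1, v2)
mi.backward()  # gradients reach v1
```

With only two bins the estimate is at most ln 2 nats, reached when the two variables are identical and balanced; with only four samples, as here, it is of course very noisy.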

Answer by Hello Worlds

TorchMetrics has MutualInfoScore.

Example from their docs:

>>> import torch
>>> from torchmetrics.clustering import MutualInfoScore
>>> preds = torch.tensor([2, 1, 0, 1, 0])
>>> target = torch.tensor([0, 2, 1, 1, 0])
>>> mi_score = MutualInfoScore()
>>> mi_score(preds, target)
tensor(0.5004)

Functional:

>>> from torchmetrics.functional.clustering import mutual_info_score
>>> target = torch.tensor([0, 3, 2, 2, 1])
>>> preds = torch.tensor([1, 3, 2, 0, 1])
>>> mutual_info_score(preds, target)
tensor(1.0549)
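To see what these numbers are, here is a minimal plug-in computation from the empirical contingency table; `plugin_mutual_info` is a hypothetical helper, not part of TorchMetrics. It reproduces both documented values, which indicates the score is reported in nats (natural log). Because the inputs are integer labels, no gradient flows through this computation, so using the metric on soft values like v1 and v2 would require thresholding them first, which removes the gradient.

```python
import torch
import torch.nn.functional as F


def plugin_mutual_info(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Plug-in MI (nats) from the empirical contingency table of two integer label tensors."""
    n = preds.numel()
    a = F.one_hot(preds).float()    # (n, number of pred labels)
    b = F.one_hot(target).float()   # (n, number of target labels)
    joint = a.T @ b / n             # empirical joint distribution
    marg = joint.sum(dim=1, keepdim=True) @ joint.sum(dim=0, keepdim=True)  # p(x) p(y)
    nz = joint > 0                  # skip empty cells (0 * log 0 = 0)
    return (joint[nz] * torch.log(joint[nz] / marg[nz])).sum()


print(plugin_mutual_info(torch.tensor([2, 1, 0, 1, 0]), torch.tensor([0, 2, 1, 1, 0])))  # tensor(0.5004)
print(plugin_mutual_info(torch.tensor([1, 3, 2, 0, 1]), torch.tensor([0, 3, 2, 2, 1])))  # tensor(1.0549)
```

Replacing the one-hot rows with soft probability rows turns this same table into a differentiable estimate.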