Calculate the Mahalanobis distance with org.apache.commons.math3 only

Is there a way to calculate the Mahalanobis distance without org.apache.mahout (that is, using only org.apache.commons.math3)?

Answer by AndreyP

One way to do this using only org.apache.commons.math3 methods is the following:

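import java.util.Arrays;

import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.stat.correlation.Covariance;
import org.apache.commons.math3.stat.descriptive.moment.VectorialMean;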

...

double[][] data = new double[][]{
...
};


RealMatrix covarianceMatrix = new Covariance(data, false).getCovarianceMatrix();
RealMatrix inverseCovarianceMatrix = MatrixUtils.inverse(covarianceMatrix);

// Column-wise mean of the data
VectorialMean mean = new VectorialMean(covarianceMatrix.getColumnDimension());
Arrays.stream(data).forEach(mean::increment);
RealVector meanVector = new ArrayRealVector(mean.getResult());


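// Calculate the Mahalanobis distance for the first row
RealVector v = new ArrayRealVector(data[0]);
RealVector diff = v.subtract(meanVector);

// RealVector has subtract/dotProduct (not minus/dot); RealMatrix.operate replaces Mahout's Algebra.mult
double distance = Math.sqrt(diff.dotProduct(inverseCovarianceMatrix.operate(diff)));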
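
For reference, the last line of the snippet above evaluates the usual Mahalanobis formula. The symbols here are my own notation and do not come from the library: x stands for the row vector `v`, μ for `meanVector`, and Σ for `covarianceMatrix`.

```latex
D_M(x) = \sqrt{(x - \mu)^{\top} \, \Sigma^{-1} \, (x - \mu)}
```

Note that Σ must be invertible for this to be defined; with a singular covariance matrix the call to `MatrixUtils.inverse` is expected to fail rather than return a result.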

Class to use instead of org.apache.mahout.common.distance.MahalanobisDistanceMeasure:

package Demo;

import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

public class MahalanobisDistanceMeasure {

    RealMatrix inverseCovarianceMatrix;

    RealVector meanVector;

    public void setCovarianceMatrix(RealMatrix covarianceMatrix) {
        this.inverseCovarianceMatrix = MatrixUtils.inverse(covarianceMatrix);
    }

    public void setMeanVector(RealVector meanVector) {
        this.meanVector = meanVector;
    }

    public double distance(RealVector vector) {
        RealVector subtract = vector.subtract(meanVector);

        return Math.sqrt(subtract.dotProduct(inverseCovarianceMatrix.operate(subtract)));
    }
}

Usage of the class:

MahalanobisDistanceMeasure measure = new MahalanobisDistanceMeasure();
measure.setCovarianceMatrix(covarianceMatrix);
measure.setMeanVector(meanVector);

double distance = measure.distance(vector);
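
To sanity-check the numbers, the same quantity can be computed by hand for two-column data, where the 2x2 covariance matrix has a closed-form inverse. The following is only a sketch and not part of commons-math3: the class name `Mahalanobis2DCheck` and the sample data are hypothetical, and it assumes the population covariance (divide by n), which is what `new Covariance(data, false)` above requests.

```java
/** Dependency-free cross-check of the Mahalanobis distance for 2-D data (illustrative only). */
final class Mahalanobis2DCheck {

    /** Distance of {@code point} from the mean of {@code data}; each row must have two columns. */
    static double distance(double[][] data, double[] point) {
        int n = data.length;

        // Column-wise mean
        double meanX = 0.0, meanY = 0.0;
        for (double[] row : data) {
            meanX += row[0];
            meanY += row[1];
        }
        meanX /= n;
        meanY /= n;

        // Population covariance entries (divide by n, i.e. no bias correction)
        double sxx = 0.0, syy = 0.0, sxy = 0.0;
        for (double[] row : data) {
            double dx = row[0] - meanX;
            double dy = row[1] - meanY;
            sxx += dx * dx;
            syy += dy * dy;
            sxy += dx * dy;
        }
        sxx /= n;
        syy /= n;
        sxy /= n;

        double det = sxx * syy - sxy * sxy;
        if (det == 0.0) {
            throw new ArithmeticException("Covariance matrix is singular");
        }

        // inverse([[sxx, sxy], [sxy, syy]]) = (1 / det) * [[syy, -sxy], [-sxy, sxx]]
        double dx = point[0] - meanX;
        double dy = point[1] - meanY;
        double quadraticForm = (syy * dx * dx - 2.0 * sxy * dx * dy + sxx * dy * dy) / det;
        return Math.sqrt(quadraticForm);
    }

    public static void main(String[] args) {
        // Hypothetical sample data, chosen so that the covariance determinant is exactly 1
        double[][] data = {{1, 2}, {2, 1}, {3, 4}, {4, 3}};
        System.out.println(distance(data, data[0]));
    }
}
```

For this sample the mean is (2.5, 2.5), the covariance matrix is [[1.25, 0.75], [0.75, 1.25]], and the distance of the first row comes out as the square root of 2 (about 1.4142). Feeding the same `data` to the commons-math3 code above should give the same value up to rounding.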