How to get centroids from Ball Tree?


According to the scikit-learn documentation, their sklearn.neighbors.BallTree class

recursively divides the data into nodes defined by a centroid C and radius r, such that each point in the node lies within the hyper-sphere defined by C and r.

Is there a way, given a BallTree instance, to extract those centroids without recomputing them? The get_arrays() method exposes the radii and the datapoints contained in each node, but does not appear to expose the centroids. In Euclidean space, one could easily compute the centroids by averaging all of the datapoints in each node, but this becomes harder under other metrics. Furthermore, a user shouldn't have to repeat this computation if the BallTree instance has already done it internally.
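For reference, the Euclidean recomputation described above might look like the sketch below. The data is made up for illustration, and the field names idx_start and idx_end on the node data array are assumptions about the structured array that get_arrays() returns, so check them against node_data.dtype.names in your scikit-learn version.

```python
import numpy as np
from sklearn.neighbors import BallTree

rng = np.random.default_rng(0)
X = rng.random((200, 3))          # illustrative data, not from the question
tree = BallTree(X, leaf_size=20)  # default metric is Minkowski with p=2 (Euclidean)

# get_arrays() returns the tree data, index array, node data and node bounds.
data, idx_array, node_data, node_bounds = tree.get_arrays()

# Assumption: node i covers the points data[idx_array[idx_start:idx_end]].
# Averaging those points gives one recomputed centroid per node.
recomputed = np.array([
    data[idx_array[node["idx_start"]:node["idx_end"]]].mean(axis=0)
    for node in node_data
])

print(recomputed.shape)  # one row per node, one column per feature
```

This is exactly the extra pass over the data that the question is trying to avoid.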

There is 1 best solution below

Ben Reiniger (best answer)

The last array in the result of the get_arrays method is the "node bounds" array, which contains the centroids:

node_bounds : the [* x n_nodes x n_features] array containing the node bound information. For ball tree, the first dimension is 1, and each row contains the centroid of the node. [...]
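A minimal sketch of reading the centroids out of that array, under a few assumptions: the data is made up for illustration, the default (Euclidean) metric is used, and the node data array is assumed to be a structured array with fields named idx_start, idx_end and radius. The loop at the end only checks the property quoted in the question, namely that every point in a node lies within the node's radius of its centroid.

```python
import numpy as np
from sklearn.neighbors import BallTree

rng = np.random.default_rng(0)
X = rng.random((200, 3))          # illustrative data
tree = BallTree(X, leaf_size=20)

data, idx_array, node_data, node_bounds = tree.get_arrays()

# For a ball tree the first dimension of node_bounds is 1,
# so node_bounds[0] has shape (n_nodes, n_features): one centroid per row.
centroids = node_bounds[0]
radii = node_data["radius"]       # assumed field name, see node_data.dtype.names

# Check that each node's points lie inside the ball (centroid, radius).
all_inside = True
for i, node in enumerate(node_data):
    pts = data[idx_array[node["idx_start"]:node["idx_end"]]]
    dists = np.linalg.norm(pts - centroids[i], axis=1)
    all_inside &= bool(np.all(dists <= radii[i] + 1e-9))

print(centroids.shape, all_inside)
```

Row i of centroids corresponds to row i of the node data array, so the radius and the centroid of a node can be looked up with the same index.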

[source]