Can't replicate clustermap plot using just sns.heatmap

Given a linkage matrix Z, the heatmap I get from g = sns.clustermap(corr) is different from the heatmap I get using sns.heatmap(corr[np.ix_(g.dendrogram_row.reordered_ind, g.dendrogram_row.reordered_ind)]).

A minimal example of the issue, using the following helper function:

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sklearn.cluster
from scipy.cluster.hierarchy import dendrogram

def get_linkage_matrix(children, n_leaves, distances): #modified from sklearn
    # Create linkage matrix
    # create the counts of samples under each node
    counts = np.zeros(children.shape[0])
    n_samples = n_leaves
    for i, merge in enumerate(children):
        current_count = 0
        for child_idx in merge:
            if child_idx < n_samples:
                current_count += 1  # leaf node
            else:
                current_count += counts[child_idx - n_samples]
        counts[i] = current_count

    linkage_matrix = np.column_stack(
        [children, distances, counts]
    ).astype(float)
    return linkage_matrix

np.random.seed(0)
toy_data = np.random.randint(0,5,(5,10))
children, n_connected_components, n_leaves, parents, distances = sklearn.cluster.ward_tree(toy_data.T, return_distance=True)
Z = get_linkage_matrix(children, n_leaves, distances) #using basically plot_dendrogram function from https://scikit-learn.org/stable/auto_examples/cluster/plot_agglomerative_dendrogram.html
plt.figure(figsize=(4,3))
R = dendrogram(Z)

This gives the following output (image: ward dendrogram).

So, then I try to get a nice sns.clustermap plot with the same dendrogram:

corr_toy = np.corrcoef(toy_data.T)
cmap = sns.diverging_palette(230, 20, as_cmap=True) 
g = sns.clustermap(corr_toy, method='ward',
                   vmax=.3, center=0, cmap=cmap,
                   row_linkage=Z, row_cluster=True, col_cluster=False,
                   linewidths=.5, cbar_kws={"shrink": .5}, row_colors=[R["leaves_color_list"][i] for i in R["leaves"]],
                   yticklabels=False, xticklabels=R["leaves"],
                   cbar_pos=(0.9, 0.3, 0.02, 0.3), figsize=(4,3))#nfeatures_df.iloc[:,R["leaves"]].columns)
mask = np.triu(np.ones_like(corr_toy, dtype=bool))
values = g.ax_heatmap.collections[0].get_array().reshape(corr_toy.shape)
new_values = np.ma.array(values, mask=mask)
g.ax_heatmap.collections[0].set_array(new_values)
plt.show()

This gives the following output (image: sns.clustermap).

Here I already expected the row_colors=[R["leaves_color_list"][i] for i in R["leaves"]] shown along the y-axis to follow the order of the colors in the dendrogram from the first plot. That does not happen, even though:

print(R["leaves"] == g.dendrogram_row.reordered_ind)
>>> True

Visual inspection also confirms that the two dendrograms match. So that is the first problem I can't work out, and any help would be much appreciated.
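
One thing that may be worth checking here, as a sketch rather than a confirmed fix: clustermap appears to reorder row_colors by the dendrogram's leaf order itself, so it would expect the colors in the original row order of the data. R["leaves_color_list"], on the other hand, is listed in the dendrogram's left-to-right leaf order. Under that assumption, the colors need to be mapped back to data order before being passed in. The helper name colors_in_data_order and the toy leaves and colors below are hypothetical.

```python
def colors_in_data_order(leaves, leaves_color_list):
    """Map colors from dendrogram leaf order back to original row order.

    Assumes leaves_color_list[k] is the color of the k-th leaf as drawn,
    and leaves[k] is the original row index of that leaf.
    """
    colors = [None] * len(leaves)
    for position, row_index in enumerate(leaves):
        colors[row_index] = leaves_color_list[position]
    return colors


# Hypothetical dendrogram output for four rows.
leaves = [2, 0, 3, 1]
leaves_color_list = ["C1", "C1", "C2", "C2"]

row_colors = colors_in_data_order(leaves, leaves_color_list)
print(row_colors)  # ['C1', 'C2', 'C1', 'C2']
```

In the example from the question, this would correspond to passing row_colors=colors_in_data_order(R["leaves"], R["leaves_color_list"]) to sns.clustermap instead of indexing leaves_color_list by R["leaves"].
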
Second, when I try to replicate the heatmap using sns.heatmap and g.dendrogram_row.reordered_ind to sort the rows appropriately, that is:

plt.figure(figsize=(4,3))
mask = np.triu(np.ones_like(corr_toy, dtype=bool))
cmap = sns.diverging_palette(230, 20, as_cmap=True)
sns.heatmap(corr_toy[np.ix_(g.dendrogram_row.reordered_ind, g.dendrogram_row.reordered_ind)], mask=mask,
            cmap=cmap, vmax=.3, center=0,
            square=True, linewidths=.5, cbar_kws={"shrink": .5},
            xticklabels=g.dendrogram_row.reordered_ind, yticklabels=g.dendrogram_row.reordered_ind)

I get this output (image: sns.heatmap), which clearly fails to replicate the sns.clustermap plot above, and I can't see why.
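
A likely source of the mismatch, offered as a sketch under assumptions: the clustermap call uses col_cluster=False, so only the rows are reordered and the columns stay in their original order, while the sns.heatmap call reorders both rows and columns through np.ix_. The small symmetric matrix and the leaf order below are hypothetical stand-ins for corr_toy and g.dendrogram_row.reordered_ind.

```python
import numpy as np

# Hypothetical symmetric matrix standing in for corr_toy.
corr = np.array([
    [1.0, 0.2, -0.5, 0.1],
    [0.2, 1.0, 0.3, -0.4],
    [-0.5, 0.3, 1.0, 0.6],
    [0.1, -0.4, 0.6, 1.0],
])

# Hypothetical leaf order standing in for g.dendrogram_row.reordered_ind.
order = [2, 0, 3, 1]

# Rows reordered, columns untouched (row clustering only).
rows_only = corr[order, :]

# Rows and columns both reordered (what np.ix_(order, order) selects).
rows_and_cols = corr[np.ix_(order, order)]

print(np.allclose(rows_only, rows_and_cols))      # False
print(np.allclose(rows_and_cols, rows_and_cols.T))  # True: still symmetric
print(np.allclose(rows_only, rows_only.T))        # False: symmetry is lost
```

If this is what is happening, passing corr_toy[g.dendrogram_row.reordered_ind, :] to sns.heatmap with the original column labels should reproduce the clustermap values. Note also that the upper-triangle mask is then applied to a matrix whose diagonal of ones no longer lies on the main diagonal, so a symmetric result would additionally require clustering the columns with the same linkage.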
