Change outline color of 3d scatterplot points based on kmeans cluster


As the title says. The data contains RGB values, which I plot in a 3D scatterplot. When I run k-means clustering, I can get the points to plot, but I would like to figure out how to plot each centroid in a different color, and outline each data point in a color that matches its centroid's color.

When I run this:

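# Import scikit-learn, a machine learning library.
import pandas as pd
import plotly.io as pio
from sklearn.cluster import KMeans

# `datasets` and `create_3d_scatter` are used below but defined elsewhere (not shown).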

# Load our classifier. 
num_clusters = 5 # You can change this if you want more/less than 5 bins!

# Fit to our data.
clusters_by_set = {}
for dataset_name, points in datasets.items():
  kmeans_cluster = KMeans(n_clusters=num_clusters, random_state=0)
  clusters_by_set[dataset_name] = kmeans_cluster.fit(points[['red','green','blue']])

# Add cluster centers to scatter.
for dataset_name, points in datasets.items():
  clusters = pd.DataFrame(clusters_by_set[dataset_name].cluster_centers_,
                          columns=['red','green','blue'])

  fig = create_3d_scatter(points, dataset_name)

  # Maybe color all points to match their cluster color?
  fig.add_trace(dict(type='scatter3d',
            x=clusters['red'],
            y=clusters['green'],
            z=clusters['blue'])
            )

  pio.show(fig)

I get this:

[Image: 3D scatterplot with k-means centroids in red]

I want to have a different color for each centroid and the outline of each point to match its centroid color, but I am completely stuck.

Answer by Derek O:

When you iterate through each dataset, you can call kmeans_cluster.fit_predict to obtain the predictions: an array with one integer label per point, identifying the cluster it belongs to. We can store these predictions in a new DataFrame column called 'cluster', which will help with the visualization.
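As a minimal sketch of this step, the snippet below runs fit_predict on a tiny made-up dataset. The names toy_points, labels and toy_with_clusters are illustrative and not from the question, and n_init=10 is just an explicit value that works across scikit-learn versions.

```python
import pandas as pd
from sklearn.cluster import KMeans

# Hypothetical toy data: two well-separated groups of RGB-like values.
toy_points = pd.DataFrame({
    'red':   [0, 5, 10, 200, 210, 220],
    'green': [0, 5, 10, 200, 210, 220],
    'blue':  [0, 5, 10, 200, 210, 220],
})

kmeans = KMeans(n_clusters=2, random_state=0, n_init=10)

# fit_predict fits the model and returns one integer cluster label per row.
labels = kmeans.fit_predict(toy_points[['red', 'green', 'blue']])

# Leave the original frame untouched and store the labels in a copy.
toy_with_clusters = toy_points.copy()
toy_with_clusters['cluster'] = labels
print(toy_with_clusters)
```

Which group receives label 0 and which receives label 1 is arbitrary; only the grouping itself is meaningful.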

Then, if we use px.scatter_3d and pass the argument color='cluster', each data point is colored according to the cluster that k-means assigned to it.

The last step is to add each centroid as its own trace, so that you can give each one a different color and each one appears in the legend. If you reuse Plotly's default color sequence, the centroid colors match the clusters in your data. I would also change the marker size or opacity so the centroids are easy to tell apart from the data points.

import numpy as np
import pandas as pd

import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

# Import scikit-learn, a machine learning library.
from sklearn.cluster import KMeans

## generate some random data that will have clusters
np.random.seed(42)
sample_points = pd.DataFrame({
    'red': np.random.uniform(low=0.0, high=100.0, size=40),
    'green': np.random.uniform(low=100.0, high=200.0, size=40),
    'blue': np.random.uniform(low=200.0, high=256.0, size=40)
})

datasets = {
    'sample': sample_points
}

# Load our classifier. 
num_clusters = 5 # You can change this if you want more/less than 5 bins!

# Fit to our data.
clusters_by_set = {}

## the default plotly colors in order
default_colors = px.colors.qualitative.Plotly

for dataset_name, points in datasets.items():
    kmeans_cluster = KMeans(n_clusters=num_clusters, random_state=0, n_init='auto')

    ## avoid modifying the original data set
    points_with_predictions = points.copy()
    ## fit once on this dataset's points (not the global sample_points) and keep the labels
    points_with_predictions['cluster'] = kmeans_cluster.fit_predict(points[['red','green','blue']])
    ## keep the fitted model so the centroids can be read back below
    clusters_by_set[dataset_name] = kmeans_cluster
    ## sort these in order so the clusters are plotted in order
    points_with_predictions = points_with_predictions.sort_values(by='cluster')
    points_with_predictions['cluster'] = points_with_predictions['cluster'].astype('category')

    fig = px.scatter_3d(points_with_predictions, x='red', y='green', z='blue', color='cluster')

    # Add cluster centers to scatter.
    clusters = pd.DataFrame(clusters_by_set[dataset_name].cluster_centers_,
                            columns=['red','green','blue'])

    ## the index of clusters should work as the identifier for the cluster
    for i, red, green, blue in clusters.itertuples():
        fig.add_trace(go.Scatter3d(
            x=[red],
            y=[green],
            z=[blue],
            name=f'cluster {i} – centroid',
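            ## wrap around the palette in case there are more clusters than default colors
            marker=dict(color=default_colors[i % len(default_colors)], size=40, opacity=0.5)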
        ))

    ## show one figure per dataset (inside the loop)
    pio.show(fig)
