I have written a class that uses the `shap` library to compute and plot SHAP feature importance. I have also added functionality to show the plots in the same window, each in its own tab. However, the first tab plots the same graph twice, and the second copy is then carried over to the other tabs. How can I get rid of the duplicate plot? The figure below shows the two plots that are created in the first tab.
import functools

import ipywidgets as widgets
import numpy as np
import shap
from IPython.display import clear_output, display
from sklearn.model_selection import train_test_split
class SHAPInterpreter:
    """
    A class that builds on top of the SHAP library to compute and plot SHAP values for a LightGBM model.

    SHAP values are computed once, in the constructor, with ``shap.TreeExplainer``
    and reused by every plotting method.
    """

    def __init__(self, model, X, y, downsample=False, sample_frac=0.2, random_state=None):
        """
        Initialize the SHAPInterpreter.

        Parameters:
        model (lightgbm.LGBMModel): A trained LightGBM model.
        X (pandas.DataFrame): The feature matrix.
        y (pandas.Series): The target vector.
        downsample (bool): Whether to downsample the data.
        sample_frac (float): The fraction of data to use for plotting.
        random_state (int): The random state to use for downsampling.
        """
        self.model = model
        self.X = X
        self.y = y
        self.explainer = shap.TreeExplainer(model)
        if downsample:
            # Keep exactly `sample_frac` of the rows. The split is requested via
            # `train_size` because the first (train) part is the one retained;
            # passing `test_size=sample_frac` here would keep the complement.
            # The split is stratified on the target.
            self.X, _, self.y, _ = train_test_split(
                self.X,
                self.y,
                train_size=sample_frac,
                stratify=self.y,
                random_state=random_state,
            )
        self.shap_values = self.explainer.shap_values(self.X)
        self.feature_names = self.X.columns.tolist()

    def summary_plot(self, max_display=10, feature_names=None, plot_type='dot', color_bar=False):
        """
        Generate a SHAP summary plot.

        Parameters:
        max_display (int): The maximum number of features to display.
        feature_names (list of str): Optional column names; when given, only these
            features are plotted (still capped by max_display).
        plot_type (str): 'dot' or 'bar'.
        color_bar (bool): Whether to show the color bar.
        """
        if feature_names is None:
            shap_values = self.shap_values
            X = self.X
        else:
            # NOTE(review): the column slice assumes `self.shap_values` is a single
            # 2-D array (rows x features); confirm for the shap version in use.
            feature_indices = [self.X.columns.get_loc(name) for name in feature_names]
            shap_values = self.shap_values[:, feature_indices]
            X = self.X.iloc[:, feature_indices]
        # Let shap show the figure itself (`show=True`) instead of calling
        # `plt.show()` afterwards: `plt` is not imported anywhere in this file.
        shap.summary_plot(
            shap_values,
            X,
            max_display=max_display,
            plot_type=plot_type,
            color_bar=color_bar,
            show=True,
            plot_size=(10, 6),
        )

    def dependence_plot(self, feature, interaction_index=None, show=True):
        """
        Generate a SHAP dependence plot.

        Parameters:
        feature (str): The feature to plot, forwarded to shap.dependence_plot.
        interaction_index: Forwarded unchanged to shap.dependence_plot.
        show (bool): Forwarded unchanged to shap.dependence_plot.

        Returns:
        Whatever shap.dependence_plot returns.
        """
        return shap.dependence_plot(
            feature,
            self.shap_values,
            self.X,
            interaction_index=interaction_index,
            show=show,
        )

    def interactive_summary_plot2(self):
        """
        Create an interactive SHAP summary plot.

        Displays the following widgets:
        - a slider for the number of features to display,
        - a scrollable list of checkboxes for selecting the feature names,
        - a dropdown for the type of plot (dot or bar),
        - an 'Update plot' button that redraws the plot into an output area.
        """
        n_columns = self.X.shape[1]
        slider = widgets.IntSlider(
            value=min(10, n_columns),
            min=1,
            max=n_columns,
            step=1,
            description='Number of features:',
        )
        dropdown = widgets.Dropdown(
            options=['dot', 'bar'],
            value='dot',
            description='Plot type:',
        )
        # Filled in by update_checkboxes; maps feature name -> Checkbox widget.
        checkboxes = {}
        checkboxes_box = widgets.VBox(layout=widgets.Layout(overflow_y='scroll'))
        button = widgets.Button(description='Update plot')
        out = widgets.Output()

        # Rank the features once by mean absolute SHAP value (largest first);
        # the ranking does not depend on the slider, only the cut-off does.
        mean_abs_shap = np.abs(self.shap_values).mean(axis=0)
        ranked_features = self.X.columns[np.argsort(mean_abs_shap)[::-1]]

        def update_checkboxes(change):
            """Rebuild the checkbox list for the slider's new value."""
            num_features = change['new']
            checkboxes.clear()
            # Every listed feature starts ticked.
            checkboxes.update({
                col: widgets.Checkbox(value=True, description=col)
                for col in ranked_features[:num_features]
            })
            checkboxes_box.children = [
                widgets.Label(value='Select features to display:'),
                widgets.VBox(
                    list(checkboxes.values()),
                    layout=widgets.Layout(overflow_y='scroll', height='150px', border='solid 1px'),
                ),
            ]

        slider.observe(update_checkboxes, names='value')

        def update_plot(_clicked):
            """Redraw the summary plot from the current widget state."""
            with out:
                clear_output(wait=True)
                selected = [name for name, box in checkboxes.items() if box.value]
                if not selected:
                    # An empty selection would hand shap a zero-column matrix.
                    print('Select at least one feature to display.')
                    return
                self.summary_plot(
                    max_display=slider.value,
                    feature_names=selected,
                    plot_type=dropdown.value,
                    color_bar=True,
                )

        button.on_click(update_plot)
        # Populate the checkbox list for the slider's initial value.
        update_checkboxes({'new': slider.value})
        display(slider, checkboxes_box, dropdown, button, out)
# Instantiate the SHAPInterpreter class with the model and the (X, y) pair taken from `modeler`.
# NOTE(review): `modeler` is not defined in this file - presumably created in an earlier notebook cell; confirm.
shap_interpreter = SHAPInterpreter(
modeler.best_model,
modeler.test_set[0],
modeler.test_set[1],
downsample=True,
sample_frac=0.2,
random_state=139)
# Manual check of the SHAP summary plot (currently disabled).
# `feature_names` is only referenced by the commented-out call below.
feature_names = ['DTB_cnt_12mth', 'DTB_cnt_6mth', 'DTB_cnt_4wk']
# shap_interpreter.summary_plot(max_display=2, feature_names=feature_names, plot_type='dot', color_bar=True)
# Manual check of the interactive SHAP summary plot (currently disabled).
# shap_interpreter.interactive_summary_plot2()
import ipywidgets as widgets
def add_to_tab(tab, title):
    """Build a decorator that runs the decorated function inside a tab of `tab`.

    Parameters:
    tab: A tab container exposing `children`, `get_title(i)` and `set_title(i, title)`
        (e.g. an ipywidgets Tab) whose children can be used as context managers.
    title (str): Title of the tab to run in. An existing tab with this title is
        reused; otherwise a new `widgets.Output()` child is appended and titled.

    Returns:
    A decorator. The decorated function keeps its name and docstring, runs inside
    the selected tab child (`with tab.children[i]:`) and returns the wrapped
    function's return value.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Look the tab up at call time, so tabs added after decoration are found.
            index = next(
                (i for i in range(len(tab.children)) if tab.get_title(i) == title),
                None,
            )
            if index is None:
                # No tab with this title yet: append a fresh output area for it.
                tab.children += (widgets.Output(),)
                index = len(tab.children) - 1
                tab.set_title(index, title)
            with tab.children[index]:
                return func(*args, **kwargs)

        return wrapper

    return decorator
def run_functions_in_tabs(func_dict, tab=None):
    """Run each function of `func_dict` inside its own tab.

    Parameters:
    func_dict (dict): Maps a tab title to a dict with the keys
        'func' (callable, required), 'args' (list, optional) and
        'kwargs' (dict, optional).
    tab: Tab container to fill. When None, a new `widgets.Tab()` is created
        and displayed.

    Returns:
    None. The tab is deliberately not returned (the original kept `return tab`
    disabled).
    NOTE(review): presumably because a returned widget on a cell's last line
    would be rendered a second time - confirm before enabling it.

    Raises:
    TypeError: If an entry has no callable under 'func'. Raised before any
        widget is created, so no half-filled tab is left behind.
    """
    for title, func_info in func_dict.items():
        if not callable(func_info.get('func')):
            raise TypeError(
                "func_dict[{!r}] must provide a callable under 'func'".format(title)
            )

    if tab is None:
        tab = widgets.Tab()
        display(tab)

    for title, func_info in func_dict.items():
        # add_to_tab is a decorator factory: it selects (or creates) the tab
        # named `title` and runs the function inside it.
        run_in_tab = add_to_tab(tab, title)(func_info['func'])
        run_in_tab(*func_info.get('args', []), **func_info.get('kwargs', {}))
# Create a dictionary mapping each tab title to the callable to run in that tab,
# with optional positional arguments under 'args'.
func_dict = {
'Summary Plot': {'func': shap_interpreter.interactive_summary_plot2},
'Dependence Plot': {'func': shap_interpreter.dependence_plot, 'args': ['DTB_cnt_8wk']},
'ROC Curve': {'func': modeler.plot_roc_curve}
}
# Run the methods in new tabs; no `tab` is passed, so run_functions_in_tabs
# creates and displays a new Tab widget.
run_functions_in_tabs(func_dict)
