Why are multiple plots generated in the first tab?

35 Views Asked by At

I have written a class that uses the "shap" library to compute and plot SHAP feature importance. I have also added functionality to plot the graphs in the same window, in different tabs. However, my first tab plots the same graph twice, and the second plot is then also carried over to the other tabs. How can I get rid of the second plot? The figure below shows how two plots are created in the first tab.

enter image description here

    import shap
    from sklearn.model_selection import train_test_split
    from IPython.display import clear_output
    
    class SHAPInterpreter:
        """
        A class that builds on top of the SHAP library to compute and plot SHAP values for a LightGBM model.

        SHAP values are computed once, at construction time, with ``shap.TreeExplainer``
        and reused by every plotting method.
        """

        def __init__(self, model, X, y, downsample=False, sample_frac=0.2, random_state=None):
            """
            Initialize the SHAPInterpreter.

            Parameters:
            model (lightgbm.LGBMModel): A trained LightGBM model.
            X (pandas.DataFrame): The feature matrix.
            y (pandas.Series): The target vector.
            downsample (bool): Whether to downsample the data (stratified on ``y``).
            sample_frac (float): The fraction of data to use for plotting.
            random_state (int): The random state to use for downsampling.
            """
            self.model = model
            self.X = X
            self.y = y
            self.explainer = shap.TreeExplainer(model)
            if downsample:
                # The first split returned by train_test_split is the one kept, so its size
                # must be given as train_size; passing sample_frac as test_size would keep
                # (1 - sample_frac) of the rows instead of sample_frac.
                self.X, _, self.y, _ = train_test_split(
                    self.X,
                    self.y,
                    train_size=sample_frac,
                    stratify=self.y,
                    random_state=random_state,
                )
            self.shap_values = self.explainer.shap_values(self.X)
            self.feature_names = self.X.columns.tolist()

        def _select_features(self, feature_names):
            """Return ``(shap_values, X)`` restricted to ``feature_names``, or unfiltered when it is None."""
            if feature_names is None:
                return self.shap_values, self.X
            feature_indices = [self.X.columns.get_loc(name) for name in feature_names]
            # NOTE(review): positional column indexing assumes shap_values is a single 2-D
            # array (rows x features) -- confirm for the shap version / model type in use.
            return self.shap_values[:, feature_indices], self.X.iloc[:, feature_indices]

        def summary_plot(self, max_display=10, feature_names=None, plot_type='dot', color_bar=False):
            """
            Generate a SHAP summary plot.

            Parameters:
            max_display (int): The maximum number of features to display.
            feature_names (list of str): The names of the features to display. Works together
                with max_display; only the features in feature_names are plotted. None plots
                all features.
            plot_type (str): It can be 'dot' or 'bar'.
            color_bar (bool): Whether to show the color bar.
            """
            shap_values, X = self._select_features(feature_names)

            shap.summary_plot(
                shap_values,
                X,
                max_display=max_display,
                plot_type=plot_type,
                color_bar=color_bar,
                show=False,
                plot_size=(10, 6))
            # Keep a handle on the figure so it can be released explicitly after display.
            fig = plt.gcf()
            plt.show()
            # Closing the figure guarantees that a backend which renders still-open figures
            # later on cannot draw it a second time. NOTE(review): suspected, not confirmed,
            # to be the source of the duplicated plot -- verify in the target frontend.
            plt.close(fig)

        def dependence_plot(self, feature, interaction_index=None, show=True):
            """Generate a SHAP dependence plot for ``feature`` and return whatever shap.dependence_plot returns."""
            return shap.dependence_plot(feature, self.shap_values, self.X, interaction_index=interaction_index, show=show)

        def interactive_summary_plot2(self):
            """
            Create an interactive SHAP summary plot.

            The displayed controls are:
            - a slider for the number of features to display,
            - a scrollable list of checkboxes for selecting the feature names to display,
            - a dropdown for selecting the type of plot (dot or bar),
            - a button that redraws the plot into an output widget.
            """
            # Create a slider for the number of features
            slider = widgets.IntSlider(
                value=min(10, self.X.shape[1]),
                min=1,
                max=self.X.shape[1],
                step=1,
                description='Number of features:',
            )

            # Create a dropdown for the plot type
            dropdown = widgets.Dropdown(
                options=['dot', 'bar'],
                value='dot',
                description='Plot type:',
            )

            # Create a placeholder for the checkboxes
            checkboxes = {}
            checkboxes_box = widgets.VBox(
                layout=widgets.Layout(overflow_y='scroll'))

            # Create a button for updating the plot
            button = widgets.Button(description='Update plot')

            # Create an output widget to display the plot
            out = widgets.Output()

            # Rank the features once, by descending mean absolute SHAP value; the ranking
            # does not depend on the slider, so it is not recomputed on every change.
            mean_shap_values = np.abs(self.shap_values).mean(axis=0)
            sorted_feature_names = self.X.columns[np.argsort(mean_shap_values)[::-1]]

            def update_checkboxes(change):
                """Rebuild the checkbox list so it offers the top ``change['new']`` ranked features."""
                num_features = change['new']
                checkboxes.clear()
                checkboxes.update({
                    col: widgets.Checkbox(value=(i < num_features), description=col)
                    for i, col in enumerate(sorted_feature_names[:num_features])
                })
                checkboxes_box.children = [
                    widgets.Label(value='Select features to display:'),
                    widgets.VBox(
                        list(checkboxes.values()),
                        layout=widgets.Layout(overflow_y='scroll', height='150px', border='solid 1px'))]

            # Attach the update_checkboxes function to the slider's value change event
            slider.observe(update_checkboxes, names='value')

            def update_plot(button):
                """Redraw the summary plot in ``out`` from the current widget state."""
                with out:
                    clear_output(wait=True)
                    feature_names = [name for name, checkbox in checkboxes.items() if checkbox.value]
                    if not feature_names:
                        # Nothing is ticked: there is no column to plot, so say so instead
                        # of handing an empty selection to the plotting call.
                        print('Select at least one feature to plot.')
                        return
                    self.summary_plot(
                        max_display=slider.value,
                        feature_names=feature_names,
                        plot_type=dropdown.value,
                        color_bar=True)

            # Attach the update_plot function to the button's click event
            button.on_click(update_plot)

            # Initialize the checkboxes
            update_checkboxes({'new': slider.value})

            # Display the slider, the checkboxes, the dropdown, the button and the output widget
            display(slider, checkboxes_box, dropdown, button, out)
    
    
    # Build the interpreter from the modeler's best model and its test set,
    # keeping only a reproducible downsample of the rows.
    shap_interpreter = SHAPInterpreter(
        model=modeler.best_model,
        X=modeler.test_set[0],
        y=modeler.test_set[1],
        downsample=True,
        sample_frac=0.2,
        random_state=139,
    )

    # Feature subset for a quick manual check of the plain summary plot, e.g.:
    # shap_interpreter.summary_plot(max_display=2, feature_names=feature_names, plot_type='dot', color_bar=True)
    feature_names = ['DTB_cnt_12mth', 'DTB_cnt_6mth', 'DTB_cnt_4wk']

    # Manual check of the interactive summary plot:
    # shap_interpreter.interactive_summary_plot2()
    
    import ipywidgets as widgets
    
    def add_to_tab(tab, title):
        """
        Build a decorator that runs the decorated function inside the tab named ``title``.

        When the wrapped function is called, the child of ``tab`` whose title equals
        ``title`` is looked up; if there is none, a new ``widgets.Output`` child is
        appended and given that title. The function is then executed inside that
        child's context manager.

        Parameters:
        tab (ipywidgets.Tab): The tab container whose children receive the output.
        title (str): The title of the tab to reuse or create.

        Returns:
        callable: A decorator. The function it produces forwards all arguments to the
            decorated function and returns that function's result.
        """
        import functools

        def decorator(func):
            @functools.wraps(func)  # keep the wrapped function's name and docstring
            def wrapper(*args, **kwargs):
                # Index of the first existing tab carrying this title, if any.
                index = next(
                    (i for i, _ in enumerate(tab.children) if tab.get_title(i) == title),
                    None,
                )
                if index is None:
                    # No such tab yet: append a fresh output area and title it.
                    tab.children += (widgets.Output(),)
                    index = len(tab.children) - 1
                    tab.set_title(index, title)

                with tab.children[index]:
                    # Hand the result back instead of silently dropping it.
                    return func(*args, **kwargs)
            return wrapper
        return decorator
    
    def run_functions_in_tabs(func_dict, tab=None):
        """
        Run every function described in ``func_dict`` inside its own tab.

        Parameters:
        func_dict (dict): Maps a tab title to a dict with the key 'func' (the callable)
            and the optional keys 'args' (positional arguments) and 'kwargs'
            (keyword arguments).
        tab (ipywidgets.Tab): An existing tab container to fill. When None, a new one
            is created and displayed before any function runs.
        """
        # Without a caller-supplied container, make one and put it on screen first.
        if tab is None:
            tab = widgets.Tab()
            display(tab)

        for tab_title in func_dict:
            spec = func_dict[tab_title]
            target = spec.get('func')
            positional = spec.get('args', [])
            keyword = spec.get('kwargs', {})

            # Wrap the callable so that it executes inside the tab named tab_title,
            # then invoke it immediately with the configured arguments.
            add_to_tab(tab, tab_title)(target)(*positional, **keyword)

        # Returning the tab widget (to reuse it in other cells) is currently disabled:
        # return tab
    
    # Describe what goes into each tab: title -> callable plus optional arguments.
    # Insertion order is the order in which the tabs are filled.
    func_dict = {}
    func_dict['Summary Plot'] = {'func': shap_interpreter.interactive_summary_plot2}
    func_dict['Dependence Plot'] = {
        'func': shap_interpreter.dependence_plot,
        'args': ['DTB_cnt_8wk'],
    }
    func_dict['ROC Curve'] = {'func': modeler.plot_roc_curve}

    # Execute every entry in a tab of its own
    run_functions_in_tabs(func_dict)

0

There are 0 best solutions below