The problem of having a hard-to-read SHAP chart


I plotted a SHAP summary plot (via shap.summary_plot) for a dataset with six inputs, two outputs, and 1000 observations.

The problem is that the dots are stacked on top of each other vertically, as shown in the attached figure, so I find it hard to interpret the effect of the inputs on the outputs.

This is my code:

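import matplotlib.pyplot as plt
import shap
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Data loading is elided in the post; df below is the loaded DataFrame.
dataset = 
x = df.iloc[:, 0:6]      # six input columns
y = df.iloc[:, [6, 7]]   # two output columns
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)
rf_model = RandomForestRegressor()
rf_pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('regressor', rf_model)
])
rf_pipeline.fit(x_train, y_train)

# Note: rf_model is fitted on scaled inputs inside the pipeline,
# while x_test is passed to the explainer unscaled.
explainer = shap.TreeExplainer(rf_model)
shap_values = explainer.shap_values(x_test)

# show=False keeps the figure open so the title lands on the SHAP plot.
shap.summary_plot(shap_values[0], x_test, show=False)
plt.title('SHAP Values for Output 1')
plt.show()

shap.summary_plot(shap_values[1], x_test, show=False)
plt.title('SHAP Values for Output 2')
plt.show()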
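
One thing that may be worth checking: the forest is fitted on the scaler's output, but the explainer receives the raw x_test. If the raw values lie far outside the scaled range, many rows can end up in the same tree leaves and get near-identical SHAP values, which would look like vertically stacked dots. The sketch below is one possible way to feed the explainer the same scaled inputs the model saw. The helper names (scaled_like_model_input, split_by_output, plot_summaries) are hypothetical, and the 3-D array layout handled in split_by_output is an assumption about newer shap releases, so treat this as a starting point rather than a confirmed fix.

```python
import numpy as np
import pandas as pd


def scaled_like_model_input(pipeline, X):
    """Run X through every pipeline step except the final estimator.

    Assumes the preprocessing steps keep the number and order of columns
    (true for StandardScaler), so the original column names still apply.
    """
    transformed = pipeline[:-1].transform(X)
    return pd.DataFrame(transformed, columns=X.columns, index=X.index)


def split_by_output(shap_values):
    """Return a list with one (n_samples, n_features) array per output.

    Handles a list of per-output arrays as well as a single array.
    The 3-D layout (n_samples, n_features, n_outputs) is an assumption.
    """
    if isinstance(shap_values, list):
        return [np.asarray(v) for v in shap_values]
    arr = np.asarray(shap_values)
    if arr.ndim == 2:
        return [arr]
    if arr.ndim == 3:
        return [arr[:, :, k] for k in range(arr.shape[2])]
    raise ValueError(f"unexpected SHAP array with {arr.ndim} dimensions")


def plot_summaries(pipeline, X):
    """Draw one titled summary plot per output, using scaled inputs."""
    # Imported here so the helpers above work without the plotting stack.
    import matplotlib.pyplot as plt
    import shap

    X_scaled = scaled_like_model_input(pipeline, X)
    explainer = shap.TreeExplainer(pipeline[-1])  # the fitted regressor
    per_output = split_by_output(explainer.shap_values(X_scaled))
    for k, values in enumerate(per_output, start=1):
        # show=False leaves the figure open so the title can be added.
        shap.summary_plot(values, X_scaled, show=False)
        plt.title(f"SHAP Values for Output {k}")
        plt.show()
```

With the names from the question, this would be called as plot_summaries(rf_pipeline, x_test). The dot colours then reflect scaled feature values rather than the original units.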

The problem is most likely related to the range of the data: when the chart is drawn, the x-axis runs from approximately -350 to 30. When I narrow the axis range, the points for the first variable (the one at the top) disappear.
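
To probe the range hypothesis, it can help to look at the numbers behind the plot before changing the axis. The sketch below summarises one output's SHAP matrix per feature. The function name shap_range_by_feature is hypothetical, and it assumes an array of shape (n_samples, n_features). A feature with a large mean absolute value but a tiny standard deviation would show up as a narrow vertical stack far from zero, which is also why narrowing the axis would push it out of view.

```python
import numpy as np


def shap_range_by_feature(values, feature_names):
    """Summarise SHAP values per feature, largest mean |value| first.

    values: array-like of shape (n_samples, n_features) for ONE output.
    Returns a list of dicts with min, max, mean_abs and std per feature.
    """
    values = np.asarray(values, dtype=float)
    if values.ndim != 2 or values.shape[1] != len(feature_names):
        raise ValueError("expected shape (n_samples, n_features)")
    rows = []
    for j, name in enumerate(feature_names):
        col = values[:, j]
        rows.append({
            "feature": name,
            "min": float(col.min()),
            "max": float(col.max()),
            "mean_abs": float(np.abs(col).mean()),
            "std": float(col.std()),  # near zero means stacked dots
        })
    # Sort by mean absolute value, so the dominant feature comes first.
    return sorted(rows, key=lambda r: r["mean_abs"], reverse=True)
```

For the first output in the question, this could be called as shap_range_by_feature(shap_values[0], list(x_test.columns)), assuming shap_values[0] is the per-output matrix as in the code above.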

Have any of you encountered such a problem before, and how did you solve it?
