I have an ONNX model that takes input of shape [1, 35, 4] ([batch_size, num_channels, seq_len]) and outputs [1, 3] ([batch_size, num_classes]). I want to assign a number to each channel that tells me how important that channel is for the predictions the model makes. I thought I would use shap's PermutationExplainer for this, but I'm not sure how to let it know that I don't care about the seq_len dimension.
I tried doing this:
```python
import onnxruntime as ort
import shap
import numpy as np

# Load the ONNX model
model_path = 'explain-example.onnx'
sess = ort.InferenceSession(model_path)

# Wrap the model: the explainer passes batches of samples, but the
# model only accepts batch_size 1, so run the samples one at a time
def model(x):
    x = x.reshape(-1, 35, 4).astype(np.float32)
    # Model expects input [1, 35, 4] and returns [1, 3]
    outputs = [sess.run(None, {'input': x[i:i + 1]})[0] for i in range(x.shape[0])]
    return np.concatenate(outputs, axis=0)

# Create sample input
X = np.random.rand(1000, 35, 4).astype(np.float32)

output_names = [f"output_{i}" for i in range(3)]
feature_names = [f"Channel {i}" for i in range(35)]

# Custom masker that is meant to zero out whole channels
def masker_fn(mask, x):
    # Problem: mask has shape (140,) while x has shape (35, 4),
    # so mask[i] below is not actually a per-channel entry
    masked_x = x.copy()
    for i in range(x.shape[0]):  # iterate over the 35 channels
        if mask[i] == 0:
            masked_x[i, :] = 0
    return masked_x

# Declare the explainer using the custom masker for channels
explainer = shap.PermutationExplainer(model, masker_fn, feature_names=feature_names, output_names=output_names)

# Compute shap values
shap_values = explainer(X)
```
but the problem is that the masks passed to the masker have shape (140,) instead of (35,). I could of course merge the mask entries across the seq_len dimension myself (sketched below), but I suspect that (1) the explainer is doing unnecessary work permuting 140 features instead of 35, and (2) this might somehow break its inner algorithm.
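For concreteness, the merge I have in mind would look something like this. It's only a sketch: it assumes the flat (140,) mask corresponds to the flattened (35, 4) sample in row-major order, so that reshaping it groups the entries by channel, and it keeps a channel only when all of its seq_len entries are kept — I haven't verified that this matches how shap actually lays out the mask:

```python
def masker_fn(mask, x):
    # Collapse the per-(channel, timestep) mask into a per-channel mask:
    # keep a channel only if all 4 of its seq_len entries are kept
    channel_mask = mask.reshape(35, 4).all(axis=1)
    masked_x = x.copy()
    masked_x[~channel_mask, :] = 0  # zero out the dropped channels
    return masked_x
```

Even if this works mechanically, the explainer would still be permuting 140 features, which is exactly the wasted work I'd like to avoid.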
How can I properly tell the explainer that there aren't 140 features, but only 35?