I applied a GCN for graph binary classification, and I want to use GNNexplainer to explain the prediction. But I don't know how to apply GNNexplainer on my model. Here is my GCN model code:
class GCNModel(nn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(GCNModel, self).__init__()
self.conv1 = GCNConv(input_dim, hidden_dim)
self.conv2 = GCNConv(hidden_dim, hidden_dim)
self.linear1 = nn.Linear(hidden_dim, output_dim)
def forward(self, data):
x, edge_index, edge_weight, batch = data.x, data.edge_index, data.edge_attr, data.batch
x = F.relu(self.conv1(x, edge_index, edge_weight))
x = self.conv2(x, edge_index, edge_weight)
x = global_mean_pool(x, batch)
x= `self.linear1(x)
x = torch.sigmoid(x)
return x
Here is the GNNexplainer code:
for batch in test_loader:
graphs_batch, labels_batch = batch
# Choose the first graph from the batch for explanation
graph_to_explain = graphs_batch[0].to(device)
labels_batch = labels_batch.to(device)
model.eval()
with torch.no_grad():
prediction = model(graph_to_explain).squeeze()
data = Data(x=graph_to_explain, edge_index=graph_to_explain.edge_index)
explainer = Explainer(
model=model,
algorithm=GNNExplainer(epochs=200),
explanation_type='model',
node_mask_type='attributes',
edge_mask_type='object',
model_config=dict(
mode='binary_classification',
task_level='graph',
return_type='probs',
),
)
explanation = explainer(data.x, data.edge_index, data.batch)
print("Node Importance Scores:", explanation.node_importance)
print("Edge Importance Mask:", explanation.edge_mask)
break
Then the error is like this :
"---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[27], line 31
18 # Initialize the GNNExplainer with your model and algorithm settings
19 explainer = Explainer(
20 model=model,
21 algorithm=GNNExplainer(epochs=200),
(...)
29 ),
30 )
---> 31 explanation = explainer(data.x, data.edge_index, data.batch)
33 # explanation = explainer(
34 # x=data.x,
35 # edge_index=data.edge_index,
(...)
40 # You can use the explanation for visualization or analysis
41 # Access the importance scores using explanation.node_importance and explanation.edge_mask
42 print("Node Importance Scores:", explanation.node_importance)
TypeError: Explainer.__call__() takes 3 positional arguments but 4 were given"
How to use GNNexplainer for graph classification in the latest version of torch_geometric (2.4.0)?