How to use GNNexplainer for graph classification in the latest version of torch-geometric (2.4.0)?

85 Views Asked by At

I applied a GCN for graph binary classification, and I want to use GNNexplainer to explain the prediction. But I don't know how to apply GNNexplainer on my model. Here is my GCN model code:

class GCNModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GCNModel, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.linear1 = nn.Linear(hidden_dim, output_dim)


    def forward(self, data):
        x, edge_index, edge_weight, batch = data.x, data.edge_index, data.edge_attr, data.batch
        
        x = F.relu(self.conv1(x, edge_index, edge_weight))
        x = self.conv2(x, edge_index, edge_weight)
        x = global_mean_pool(x, batch)
        x= `self.linear1(x)
        x = torch.sigmoid(x)
        return x

Here is the GNNexplainer code:

for batch in test_loader:
    graphs_batch, labels_batch = batch
    # Choose the first graph from the batch for explanation
    graph_to_explain = graphs_batch[0].to(device)
    labels_batch = labels_batch.to(device)
    
    model.eval()
    with torch.no_grad():
        prediction = model(graph_to_explain).squeeze()

        data = Data(x=graph_to_explain, edge_index=graph_to_explain.edge_index)

        explainer = Explainer(
            model=model,
            algorithm=GNNExplainer(epochs=200),
            explanation_type='model',
            node_mask_type='attributes',
            edge_mask_type='object',
            model_config=dict(
                mode='binary_classification',
                task_level='graph',
                return_type='probs',
            ),
        )
        explanation = explainer(data.x, data.edge_index, data.batch)

        print("Node Importance Scores:", explanation.node_importance)
        print("Edge Importance Mask:", explanation.edge_mask)

    break

Then the error is like this :


"---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[27], line 31
     18 # Initialize the GNNExplainer with your model and algorithm settings
     19 explainer = Explainer(
     20     model=model,
     21     algorithm=GNNExplainer(epochs=200),
   (...)
     29     ),
     30 )
---> 31 explanation = explainer(data.x, data.edge_index, data.batch)
     33 # explanation = explainer(
     34 #     x=data.x,
     35 #     edge_index=data.edge_index,
   (...)
     40 # You can use the explanation for visualization or analysis
     41 # Access the importance scores using explanation.node_importance and explanation.edge_mask
     42 print("Node Importance Scores:", explanation.node_importance)

TypeError: Explainer.__call__() takes 3 positional arguments but 4 were given"

How to use GNNexplainer for graph classification in the latest version of torch_geometric (2.4.0)?

0

There are 0 best solutions below