Classification in pytorch-forecasting's Temporal Fusion Transformer


I am implementing a TFT model and came across this table: https://pytorch-forecasting.readthedocs.io/en/stable/models.html

It states that a TFT model can be used for classification tasks. That seems unintuitive to me, since TFT is used for time series forecasting, which is typically a regression task.

I have two questions:

  1. Does it make sense to use a TFT model for classification?
  2. I implemented it using BCEWithLogitsLoss as the loss function and set its pos_weight parameter to weight positive labels more heavily, because the dataset is zero-inflated:
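import numpy as np
import torch
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import convert_torchmetric_to_pytorch_forecasting_metric

# Weight the positive class by the negative-to-positive ratio.
positives = np.sum(train_data['fridge'].values == 1)
negatives = np.sum(train_data['fridge'].values == 0)
positive_weight = torch.tensor(negatives/positives, dtype=torch.float)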

tft = TemporalFusionTransformer.from_dataset(
    training,
    learning_rate=LEARNING_RATE,
    lstm_layers=2,
    hidden_size=16,
    attention_head_size=4,
    dropout=0.2,
    hidden_continuous_size=8,
    output_size=1,
    loss=convert_torchmetric_to_pytorch_forecasting_metric(
        torch.nn.BCEWithLogitsLoss(pos_weight=positive_weight)
    ),
    log_interval=10,
    reduce_on_plateau_patience=4,
)

However, the model now predicts negative values. As far as I can tell, the TFT model uses ReLU as its activation function, and I cannot change it to something like sigmoid. Do you know how to overcome this issue and get a usable classification out of the TFT model?

There are 1 best solutions below

Philippe Ostiguy On

1- I noticed that PyTorch Forecasting supports classification for TFT, but originally, the TFT model was designed for regression, not classification. Generally speaking, you should use a model for its intended purpose.

2- PyTorch Forecasting provides an example of how to implement a custom model for classification:
https://pytorch-forecasting.readthedocs.io/en/stable/tutorials/building.html#Classification.
You could do the same with the TFT model:
https://pytorch-forecasting.readthedocs.io/en/stable/_modules/pytorch_forecasting/models/temporal_fusion_transformer.html#TemporalFusionTransformer.

Also, make sure to set output_size=2 for two classes.
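
As a side note on the negative predictions: a loss such as BCEWithLogitsLoss applies the sigmoid internally, so the values the network emits are raw logits, and negative numbers are expected. They only become probabilities after a sigmoid (one output) or a softmax (one output per class). The sketch below illustrates that post-processing in plain Python. The logit values are made up for illustration, and the helper names `sigmoid` and `softmax` are mine, not part of PyTorch Forecasting; on real tensors, `torch.sigmoid` and `torch.softmax` do the same job.

```python
import math


def sigmoid(logit: float) -> float:
    """Map a single raw logit to a probability in (0, 1)."""
    return 1.0 / (1.0 + math.exp(-logit))


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax over one row of class logits."""
    shift = max(logits)
    exps = [math.exp(x - shift) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical raw model outputs for three time steps (made-up values).
single_logits = [-2.0, 0.0, 3.0]  # one logit per step (output_size=1)
two_class_logits = [[1.5, -0.5], [0.2, 0.2], [-1.0, 2.0]]  # output_size=2

# One logit: sigmoid, then threshold at 0.5.
probs_single = [sigmoid(z) for z in single_logits]
labels_single = [int(p >= 0.5) for p in probs_single]

# Two logits: softmax, then pick the most probable class (ties go to class 0).
probs_two = [softmax(row) for row in two_class_logits]
labels_two = [max(range(len(row)), key=row.__getitem__) for row in probs_two]

print(labels_single)
print(labels_two)
```

Which variant applies depends on the loss: a single logit pairs with a binary cross-entropy on logits, while output_size=2 pairs with a cross-entropy over two classes, as in the tutorial linked above.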