How are the losses computed for the CATS algorithm?

I am struggling to compute the losses when using the CATS algorithm in the VowpalWabbit library. Does anyone know how they are computed (the average and the last)?

I tried to calculate the average of the cost, as I found in the documentation (loss = cost = -reward).

Idriss Ben Hmida On

The loss that the CATS algorithm reports is computed by the get_loss() function, here: https://github.com/VowpalWabbit/vowpal_wabbit/blob/master/vowpalwabbit/core/src/reductions/cats.cc#L58-L82.

What it does can be broken down into a few steps:

1. Normalize and discretize the set of actions (first by their interval, then into buckets when we floor the "ac"). What this does is turn the chosen action into an "action index", much like that of the standard CB algorithm.

2. Use this index to compute the "center" position of the action - in other words, the action we would get if we always chose the center, rather than some other point within the bandwidth of the discretization.

3. Compare the logged action with this center. Only if the logged action falls within the bandwidth do we compute a loss (this is functionally equivalent to an indicator function).

4. If we are computing the loss, ensure that we properly account for actions whose bandwidth exceeds the min/max allowed, and then use this, along with the logged probability of choosing the action, to perform an IPS-like computation over the cost of choosing the action.
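The steps above can be sketched roughly in Python. This is only an illustration based on my reading of the linked get_loss() function, not the library's code: the function name cats_loss_sketch and all parameter names are hypothetical, and the exact arithmetic (in particular dividing the cost by the logged density times the clipped bandwidth width) is an assumption you should check against cats.cc.

```python
import math


def cats_loss_sketch(predicted_action, logged_action, cost, pdf_value,
                     min_value, max_value, num_actions, bandwidth):
    """Hypothetical sketch of the per-example loss described above.

    All names are illustrative assumptions, not VowpalWabbit's API.
    """
    # Step 1: normalize the predicted action by the action interval,
    # then floor it into one of num_actions buckets.
    unit_range = (max_value - min_value) / num_actions
    ac = (predicted_action - min_value) / unit_range
    action_index = min(num_actions - 1, math.floor(ac))

    # Step 2: center of the bucket the predicted action fell into.
    center = min_value + action_index * unit_range + unit_range / 2.0

    # Step 3: indicator-like check. The logged action only contributes
    # if it lies within one bandwidth of the center.
    if abs(logged_action - center) > bandwidth:
        return 0.0

    # Step 4: clip the bandwidth window to [min_value, max_value], then
    # do an IPS-like reweighting of the logged cost (assumed form).
    window = min(max_value, center + bandwidth) - max(min_value, center - bandwidth)
    return cost / (pdf_value * window)
```

For instance, with actions in [0, 100], 4 buckets and a bandwidth of 10, a predicted action of 30 lands in the bucket centered at 37.5, so a logged action of 40 contributes a loss while a logged action of 60 contributes none. If this reading is right, it would also explain why averaging the raw costs does not reproduce the reported numbers: the per-example loss is either zero or the cost rescaled by the logged density and the window width.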