I have a semantic segmentation task, which I'm solving using PyTorch. I use dice loss + BCE as the loss function. I know that each image has exactly one mask, and I want to add a penalty if the predicted mask is split into several parts. Which loss function can I use for this?
Loss function for semantic segmentation that penalizes mask separation
Asked by ZFTurbo
I would like to add my 2 cents. Some models achieve this with a post-processing step instead.
Rather than penalising the model by modifying the loss, they keep the loss as it is. After the model has produced the predicted mask, all but the largest connected component are removed from it. This guarantees a single connected region per prediction, so disconnected masks cannot occur. Do check whether this approach is suitable for your use case.
For 3D masks there are libraries such as Connected-Components-3D that do this. OpenCV supports it for 2D masks as well; see this answer.
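A minimal sketch of the "keep only the largest connected component" step. It uses `scipy.ndimage.label` as the labeling routine; that is a swap-in for illustration, since the answer names Connected-Components-3D and OpenCV, either of which could do the labeling instead. It assumes the prediction has already been thresholded to a binary 2D or 3D mask, and `keep_largest_component` is a hypothetical helper name.

```python
import numpy as np
from scipy import ndimage


def keep_largest_component(mask: np.ndarray) -> np.ndarray:
    """Return a boolean copy of a binary mask that keeps only its largest connected component.

    Works for 2D and 3D masks. Uses scipy's default connectivity
    (face neighbours only: 4-connectivity in 2D, 6-connectivity in 3D).
    """
    binary = mask.astype(bool)
    # Label each connected region with an integer 1..num (0 = background).
    labels, num = ndimage.label(binary)
    if num == 0:
        # Empty prediction: there is nothing to keep.
        return np.zeros_like(binary)
    # Pixel/voxel count per label; zero out the background so it is never chosen.
    sizes = np.bincount(labels.ravel())
    sizes[0] = 0
    return labels == sizes.argmax()


pred = np.array([[1, 1, 0, 0],
                 [1, 1, 0, 1],
                 [0, 0, 0, 1]])
clean = keep_largest_component(pred)
```

Here `clean` keeps the four-pixel block in the top-left corner and drops the two-pixel fragment on the right. Note that an empty prediction stays empty, so this step only fixes fragmentation, not missed objects.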
If you do want to penalize the model for predicting disconnected components, make sure that penalty is less severe than the penalty for not producing a mask at all; otherwise the model can lower its loss by simply predicting an empty mask.
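So I think a safe way to do that is to penalize false-negative pixels. In this direction you may explore Focal loss (https://paperswithcode.com/method/focal-loss). Focal loss down-weights easy, confidently classified pixels so that training concentrates on hard, misclassified ones; the missed foreground pixels in the gap that splits a mask tend to be among those hard pixels, so they receive a larger share of the penalty.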
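A minimal sketch of binary focal loss for single-channel segmentation logits, written directly in PyTorch. `alpha` and `gamma` are the parameters from the focal loss paper; the defaults (0.25 and 2.0) follow `torchvision.ops.sigmoid_focal_loss`, which computes the same per-pixel quantity. `binary_focal_loss` is a hypothetical helper name. To lean toward penalizing false negatives, `alpha` can be raised above 0.5 so foreground pixels carry more weight; the exact values are something to tune, not a recommendation.

```python
import torch
import torch.nn.functional as F


def binary_focal_loss(
    logits: torch.Tensor,
    targets: torch.Tensor,
    alpha: float = 0.25,
    gamma: float = 2.0,
    reduction: str = "mean",
) -> torch.Tensor:
    """Per-pixel binary focal loss on raw logits; targets have the same shape, values 0/1."""
    targets = targets.float()
    # Unreduced BCE equals -log(p_t), computed in a numerically stable way.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    # p_t: predicted probability of the true class at each pixel.
    p_t = p * targets + (1 - p) * (1 - targets)
    # alpha_t: alpha weights foreground pixels, (1 - alpha) weights background pixels.
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t) ** gamma shrinks the loss of pixels that are already classified confidently.
    loss = alpha_t * (1 - p_t) ** gamma * bce
    if reduction == "mean":
        return loss.mean()
    if reduction == "sum":
        return loss.sum()
    return loss
```

With `gamma=0` and `alpha=0.5` this reduces to half of plain BCE, which is a quick sanity check. In the asker's setup it could replace the BCE term, e.g. `binary_focal_loss(logits, masks, alpha=0.75)` plus the existing dice term.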
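You may also explore other custom differentiable losses. Since you always have exactly one output mask, you could try a loss that penalizes total contour length: a mask split into several fragments usually has more boundary than a single blob of the same area. Because the loss is applied to the soft predictions before thresholding, you can use the sum or mean of the spatial gradients of the predicted probability map (with larger shifts than the one-pixel differences typically used in edge detection).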
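A minimal sketch of that contour-length idea, assuming probability maps of shape `(N, 1, H, W)` (sigmoid of the logits, before thresholding). The penalty is the mean (or sum) of absolute finite differences at a configurable `shift`, essentially an anisotropic total-variation term. `contour_penalty`, `dice_bce_loss`, `total_loss`, and the weight `lam` are hypothetical names, and the dice + BCE part is one common formulation, not necessarily the asker's exact loss.

```python
import torch
import torch.nn.functional as F


def contour_penalty(probs: torch.Tensor, shift: int = 3, reduction: str = "mean") -> torch.Tensor:
    """Approximate boundary length of a soft mask with shifted finite differences.

    probs: (N, 1, H, W) probabilities in [0, 1].
    shift: pixel offset of the differences; larger than 1, as suggested above.
    """
    h, w = probs.shape[-2:]
    if not 1 <= shift < min(h, w):
        raise ValueError("shift must be >= 1 and smaller than the spatial size")
    # Vertical and horizontal differences between pixels `shift` apart.
    dy = (probs[..., shift:, :] - probs[..., :-shift, :]).abs()
    dx = (probs[..., :, shift:] - probs[..., :, :-shift]).abs()
    if reduction == "sum":
        return dy.sum() + dx.sum()
    return dy.mean() + dx.mean()


def dice_bce_loss(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical baseline: BCE plus soft dice, averaged over the batch."""
    bce = F.binary_cross_entropy_with_logits(logits, targets)
    probs = torch.sigmoid(logits)
    inter = (probs * targets).sum(dim=(1, 2, 3))
    denom = probs.sum(dim=(1, 2, 3)) + targets.sum(dim=(1, 2, 3))
    dice = 1 - (2 * inter + eps) / (denom + eps)
    return bce + dice.mean()


def total_loss(logits: torch.Tensor, targets: torch.Tensor, lam: float = 0.01, shift: int = 3) -> torch.Tensor:
    """Dice + BCE plus a small, hypothetical-weighted contour penalty on the soft mask."""
    return dice_bce_loss(logits, targets) + lam * contour_penalty(torch.sigmoid(logits), shift=shift)
```

On a 32x32 canvas with `shift=3` and `reduction="sum"`, a single 4x8 blob scores 72 while two separated 4x4 blobs (same total area) score 96, so fragmentation costs more. Keeping `lam` small matters for the reason given earlier: an all-empty prediction has a penalty of exactly zero, so an overly large weight can push the model toward predicting nothing.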