Guidance with a model that works with grayscale images


I have a model that consists of three Conv2d layers with ReLU activations. It takes as input grayscale images normalised to the interval [0, 1]. The input images contain some black regions, some white regions, and others in between.

However, the output's dynamic range is compressed to [0.4, 0.401], so every output image is grey, even after renormalising it back to [0, 255].
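For reference, the normalisation and renormalisation described above can be sketched as follows. This assumes 8-bit grayscale inputs and that "renormalisation" means multiplying by 255; the helper names `to_unit_range`, `to_uint8` and `dynamic_range` are hypothetical. The example shows why an output confined to [0.4, 0.401] still looks uniformly grey after rescaling: every pixel rounds to the same 8-bit value.

```python
import torch


def to_unit_range(img_uint8: torch.Tensor) -> torch.Tensor:
    """Map an 8-bit grayscale image (H, W) to a float tensor in [0, 1] of shape (1, 1, H, W)."""
    return (img_uint8.float() / 255.0).unsqueeze(0).unsqueeze(0)


def to_uint8(img_unit: torch.Tensor) -> torch.Tensor:
    """Map a float image assumed to lie in [0, 1] back to 8-bit values in [0, 255]."""
    return (img_unit.clamp(0.0, 1.0) * 255.0).round().to(torch.uint8)


def dynamic_range(img: torch.Tensor) -> tuple[float, float]:
    """Smallest and largest pixel value of an image tensor."""
    return img.min().item(), img.max().item()


# A stand-in output whose values span only [0.4, 0.401].
output = torch.linspace(0.4, 0.401, steps=16).reshape(1, 1, 4, 4)
print(dynamic_range(output))            # approximately (0.4, 0.401)
print(torch.unique(to_uint8(output)))   # a single grey level: 102
```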

I’m a bit lost and can’t figure out why this happens.

What I tried:

  • Plotted histograms of the gradients, although I’m not sure how to interpret them. With some settings the whole model seems to stop learning after a few epochs; with others, a single layer stops learning.
  • Modified the kernel sizes
  • Added more layers
  • Tried many different learning rates
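One way to make the gradient inspection easier to read than histograms is to log a single number per parameter tensor, plus the fraction of ReLU outputs that are exactly zero. The sketch below uses hypothetical helpers `gradient_norms` and `dead_relu_fractions` on a small stand-in model, not the model from this question. When training with a `GradScaler`, call `scaler.unscale_(optimizer)` before reading the gradients; otherwise the norms include the loss-scale factor. A layer whose gradient norm stays near zero while the others do not, or whose ReLU zero fraction approaches 1.0, is one that has effectively stopped learning.

```python
import torch
import torch.nn as nn


def gradient_norms(model: nn.Module) -> dict[str, float]:
    """L2 norm of each parameter's gradient; call after backward() and before zero_grad()."""
    return {
        name: p.grad.detach().norm().item()
        for name, p in model.named_parameters()
        if p.grad is not None
    }


def dead_relu_fractions(model: nn.Module, x: torch.Tensor) -> dict[str, float]:
    """Fraction of exactly-zero outputs after each nn.ReLU, for one input batch."""
    fractions: dict[str, float] = {}
    handles = []
    for name, module in model.named_modules():
        if isinstance(module, nn.ReLU):
            # Bind the current name as a default argument so each hook keeps its own label.
            def hook(mod, inp, out, name=name):
                fractions[name] = (out == 0).float().mean().item()
            handles.append(module.register_forward_hook(hook))
    with torch.no_grad():
        model(x)
    for handle in handles:
        handle.remove()  # detach the hooks so later forward passes are unaffected
    return fractions


# Usage on a small stand-in model (hypothetical, for illustration only).
torch.manual_seed(0)
model = nn.Sequential(nn.Conv2d(1, 4, 3, padding=1), nn.ReLU(True), nn.Conv2d(4, 1, 3, padding=1))
x = torch.rand(2, 1, 8, 8)
loss = nn.functional.mse_loss(model(x), torch.rand(2, 1, 8, 8))
loss.backward()
print(gradient_norms(model))         # keys: '0.weight', '0.bias', '2.weight', '2.bias'
print(dead_relu_fractions(model, x))  # key: '1'
```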

Below are my model and my training loop.

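```python
import torch
import torch.nn as nn


class Model(nn.Module):
    """Three Conv2d layers with ReLU activations: 1 -> 64 -> 32 -> 1 channels."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=64, kernel_size=13, stride=1, padding=5),
            nn.ReLU(True),
            nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=2),
            nn.ReLU(True),
            # No activation after the last layer, so the output is not bounded to [0, 1].
            nn.Conv2d(in_channels=32, out_channels=1, kernel_size=5, stride=1, padding=2),
        )

    def forward(self, input_image):
        # input_image: (N, 1, H, W) with values in [0, 1]
        output_image = self.model(input_image)
        return output_image
```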
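The paddings in `Model` are not "same" padding layer by layer, although they cancel out overall. The sketch below applies the Conv2d output-size formula from the PyTorch documentation to trace the spatial size through the three layers; the helper names `conv2d_out_size` and `trace_sizes` are hypothetical. The 13×13 convolution with padding 5 shrinks each spatial dimension by 2 pixels and the 3×3 convolution with padding 2 grows it back by 2. For stride 1 and an odd kernel size k, padding (k - 1) / 2 keeps the size unchanged at every layer, which would be 6, 1 and 2 here.

```python
def conv2d_out_size(size: int, kernel: int, padding: int, stride: int = 1, dilation: int = 1) -> int:
    """Output height (or width) of nn.Conv2d for one spatial dimension."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# (kernel_size, padding) of the three layers in Model, all with stride 1
LAYERS = [(13, 5), (3, 2), (5, 2)]


def trace_sizes(size: int) -> list[int]:
    """Spatial size after each layer, starting from an input of the given size."""
    sizes = []
    for kernel, padding in LAYERS:
        size = conv2d_out_size(size, kernel, padding)
        sizes.append(size)
    return sizes


print(trace_sizes(64))  # [62, 64, 64]
```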
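
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.cuda import amp
from torch.utils.data import DataLoader

LOG_EVERY = 100  # report the mean loss once per this many batches


def train_one_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    loss_fn: nn.MSELoss,
    optimizer: optim.Adam,
    epoch_index: int,
    scaler: amp.GradScaler,
) -> float:
    model.train()
    device = next(model.parameters()).device  # move batches to wherever the model lives
    running_loss = 0.0
    last_loss = 0.0

    # Iterate over the dataloader passed in, not a global train_dataloader.
    for batch_index, batch in enumerate(dataloader):
        input_img = batch['input'].to(device, non_blocking=True)
        gt_img = batch['gt'].to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)

        # Mixed-precision forward pass and loss computation.
        with amp.autocast():
            output = model(input_img)
            loss = loss_fn(output, gt_img)

        # Scale the loss to avoid fp16 gradient underflow, then step and update the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()
        if batch_index % LOG_EVERY == LOG_EVERY - 1:
            # Mean loss over the last LOG_EVERY batches.
            last_loss = running_loss / LOG_EVERY
            running_loss = 0.0

    return last_loss
```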
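The loop reports the loss once every 100 batches, and the reported value is only a mean if the accumulated sum is divided by that same window length. The pure-Python sketch below (hypothetical helper `windowed_means`, operating on a plain list of per-batch losses) mirrors that logic. Note that batches after the last complete window are never reported, so an epoch with fewer than 100 batches returns the initial value of 0.

```python
def windowed_means(losses: list[float], window: int = 100) -> list[float]:
    """Mean loss over each complete window of `window` consecutive batches."""
    means = []
    running = 0.0
    for i, loss in enumerate(losses):
        running += loss
        if i % window == window - 1:
            means.append(running / window)  # divide by the window length
            running = 0.0
    return means


print(windowed_means([1.0] * 250))  # [1.0, 1.0]; the last 50 batches are not reported
```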

Any help is appreciated. Please let me know if more information is needed.
