Cannot find why my deeplabv3+ model shows bad performance

21 Views Asked by At

The DeepLabV3+(with PyTorch) model using the Xception backbone was implemented by referring to the paper as much as possible, but the performance of the training using the VOC2012 dataset is too low.

One-hot-processed datasets for SegmentationClass of datasets were used for learning, and the following aspects occurred when the loss and miou graph were checked using WandB:

  1. It shows a lower value of validation loss than training, and it shows a flat pattern rather than a decreasing pattern from the first epoch until it stops.
  2. miou is similarly flat, and verification miou hardly changes.

The optimizer applied AdamW, and the scheduler used ReduceLROnPlateau (mode='max') for miou.

I've tried a lot of other codes, but the results were always the same, so I couldn't find an answer anymore, so I left a question here.

I'll write down the model code here, so please try it and let me know if you've found the answer.

Separble convolution:

class SeparableConv2d(nn.Module):
  def __init__(
      self, in_channels, out_channels, kernel_size=3,
      stride=1, padding=0, dilation=1, bias=True, depthwise=False
      ):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size
    self.stride=stride
    self.padding = padding
    self.dilation = dilation
    self.depth = depthwise

    if self.depth:
      self.depthwise = nn.Sequential(
        nn.Conv2d(self.in_channels, self.in_channels, self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=in_channels, bias=False),
        nn.BatchNorm2d(self.in_channels),
        nn.ReLU()
      )
    else:
      self.depthwise = nn.Conv2d(self.in_channels, self.in_channels, self.kernel_size, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=in_channels, bias=False)

    self.pointwise=nn.Conv2d(self.in_channels, self.out_channels, kernel_size=1, bias=bias)

  def forward(self, x):
    out = self.depthwise(x)
    out = self.pointwise(out)
    return out

Residual block:

class residualBlock(nn.Module):
    def __init__(self, input, output, atrous:list, strides:list, relus:list,
                 residual=True, last_depth=True):
        super(residualBlock, self).__init__()
        self.input=input
        self.output=output
        self.residual = residual
        self.atrous = atrous
        self.strides = strides
        self.relus = relus
        self.last_depth = last_depth

        if self.atrous == None:
            atrous = [1]*3
        elif isinstance(self.atrous, int):
            atrous = [self.atrous]*3

        def block(input, output, padding, dilation, stride=1, relu=False, depth=False):
            if relu:
              return nn.Sequential(
                nn.ReLU(),
                SeparableConv2d(input, output, dilation=dilation, stride=stride, bias=False),
              )
            else:
              return SeparableConv2d(input, output, dilation=dilation, stride=stride, bias=False, depthwise=depth)


        #block3
        #get additional separable convolution instead of max pooling in original xception model

        def residualblock(input, output):
            return nn.Sequential(
                nn.Conv2d(input, output, kernel_size=1, stride=2, bias=False),
                nn.BatchNorm2d(output),
                nn.ReLU()
            )

        self.block1 = block(self.input, self.output, padding=atrous[0], dilation=atrous[0], stride=self.strides[0], relu=self.relus[0])
        self.block2 = block(self.output, self.output, padding=atrous[1], dilation=atrous[1], stride=self.strides[1], relu=self.relus[1])
        self.block3 = block(output, output, padding=atrous[2], dilation=atrous[2], stride=self.strides[2], relu=self.relus[2], depth = self.last_depth)
        self.residualblock = residualblock(input, output)

    def forward(self, x):
        res = x
        #block1
        x = self.block1(x)

        #block2
        x = self.block2(x)

        #block3
        x = self.block3(x)
        #add residual
        if self.residual:
            resblock = self.residualblock(res)
            #resizing x to resblock
            x = F.interpolate(x, size=resblock.size()[2:], mode='bilinear', align_corners=True)
            x = x + resblock
        else:
            x = F.interpolate(x, size=res.size()[2:], mode='bilinear', align_corners=True)
            x = x + res
        return x

Backbone:

class Xception(nn.Module):
    def __init__(self, nInputChannels=3, os=16):
        super(Xception, self).__init__()
        stride_list = None
        self.os = os
        self.input = nInputChannels
        if self.os == 8:
            stride_list = [2,1,1]
        elif self.os == 16:
            stride_list = [2,2,1]

        def strideconv(input, output, checkstride):
            if checkstride:
                return nn.Sequential(
                    nn.Conv2d(input, output, kernel_size=3, stride=2, bias=False),
                    nn.ReLU()
                )
            else:
                return nn.Sequential(
                    nn.Conv2d(input, output, kernel_size=3, bias=False),
                    nn.ReLU()
                )

        # Entry flow
        self.entry_conv1 = strideconv(self.input, 32, True)
        self.entry_conv2 = strideconv(32, 64, False)
        self.entry_conv3 = residualBlock(64, 128, atrous=stride_list[0], strides=[1,1,2], relus=[False, True, False])
        self.entry_conv4 = residualBlock(128, 256, atrous=stride_list[0], strides=[1,1,2], relus=[True, True, False])
        self.entry_conv5 = residualBlock(256, 728, atrous=stride_list[0], strides=[1,1,2], relus=[True, True, False])

        # Middle flow
        self.mid01 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid02 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid03 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid04 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid05 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid06 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid07 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid08 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid09 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid10 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid11 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid12 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid13 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid14 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid15 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)
        self.mid16 = residualBlock(728, 728, atrous=stride_list[1], residual=False, strides=[1,1,1], relus=[True,True,True], last_depth=False)

        # Exit flow
        self.exit_residual = residualBlock(728,1024,stride_list[2], strides=[1,1,2], relus=[True, True, False])
        self.exit_conv1 = SeparableConv2d(1024, 1536, kernel_size=3, stride=stride_list[2], bias=False)
        self.exit_relu1 = nn.ReLU()
        self.exit_conv2 = SeparableConv2d(1536, 1536, kernel_size=3, stride=stride_list[2], bias=False)
        self.exit_relu2 = nn.ReLU()
        self.exit_conv3 = SeparableConv2d(1536, 2048, kernel_size=3, stride=stride_list[2], bias=False)

    def forward(self, x):
        # Entry flow
        entry_out1 = self.entry_conv1(x)
        entry_out2 = self.entry_conv2(entry_out1)
        entry_out3 = self.entry_conv3(entry_out2)
        entry_out4 = self.entry_conv4(entry_out3)
        entry_out5 = self.entry_conv5(entry_out4)

        low_level_features = entry_out5

        # Middle flow
        mid_out01 = self.mid01(entry_out5)
        mid_out02 = self.mid02(mid_out01)
        mid_out03 = self.mid03(mid_out02)
        mid_out04 = self.mid04(mid_out03)
        mid_out05 = self.mid05(mid_out04)
        mid_out06 = self.mid06(mid_out05)
        mid_out07 = self.mid07(mid_out06)
        mid_out08 = self.mid08(mid_out07)
        mid_out09 = self.mid09(mid_out08)
        mid_out10 = self.mid10(mid_out09)
        mid_out11 = self.mid11(mid_out10)
        mid_out12 = self.mid12(mid_out11)
        mid_out13 = self.mid13(mid_out12)
        mid_out14 = self.mid13(mid_out13)
        mid_out15 = self.mid15(mid_out14)
        mid_out16 = self.mid16(mid_out15)

        # Exit flow
        exit_out1 = self.exit_residual(mid_out16)
        exit_out2 = self.exit_conv1(exit_out1)
        exit_out3 = self.exit_conv2(exit_out2)
        out = self.exit_conv3(exit_out3)

        return out, low_level_features

ASPP:

class ASPPConv(nn.Module):
    def __init__(self, in_channels, out_channels, dilation):
      super(ASPPConv, self).__init__()
      self.conv = nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False)
      self.bn = nn.BatchNorm2d(out_channels)
      self.relu = nn.ReLU()

    def forward(self, x):
      conv = self.conv(x)
      bn = self.bn(conv)
      out = self.relu(bn)

      return out


class ASPPPooling(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super(ASPPPooling, self).__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv2d(in_channels, out_channels, 1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU()

    def forward(self, x):
        size = x.shape[-2:]
        x = self.pool(x)
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return F.interpolate(x, size=size, mode="bilinear", align_corners=True)

class ASPP(nn.Module):
    def __init__(self, in_channels, atrous_rates, out_channels = 256):
        super(ASPP, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.atrous_rates = atrous_rates

        modules = []
        modules.append(
            nn.Sequential(
                nn.Conv2d(self.in_channels, self.out_channels, 1, bias=False),
                nn.BatchNorm2d(self.out_channels),
                nn.ReLU())
        )
        rates = self.atrous_rates
        for i in range(len(rates)):
            modules.append(ASPPConv(self.in_channels, self.out_channels, rates[i]))
            modules.append(ASPPConv(self.in_channels, self.out_channels, rates[i]))
            modules.append(ASPPConv(self.in_channels, self.out_channels, rates[i]))

        modules.append(ASPPPooling(self.in_channels, self.out_channels))

        self.convs = nn.ModuleList(modules)

        self.project = nn.Sequential(
            nn.Conv2d(len(self.convs) * self.out_channels, self.out_channels, 1, bias=False),
            nn.BatchNorm2d(self.out_channels),
            nn.ReLU(),
            nn.Dropout(0.5),
        )

    def forward(self, x):
        _res = []
        for conv in self.convs:
            _res.append(conv(x))
        res = torch.cat(_res, dim=1)
        return self.project(res)

Decoder:

class Decoder(nn.Module):
    def __init__(self, in_channels, num_classes):
        super(Decoder, self).__init__()
        self.in_channels = in_channels
        self.num_classes = num_classes

        #1x1 convolution of low-level features
        self.conv1 = nn.Conv2d(self.in_channels, 48, kernel_size=1, stride=1)
        self.bn1 = nn.BatchNorm2d(48)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(304, 256, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(256)
        self.relu2 = nn.ReLU()
        self.drop2 = nn.Dropout(0.5)

        #3x3 convolution
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm2d(256)
        self.relu3 = nn.ReLU()
        self.drop3 = nn.Dropout(0.1)

        self.conv4 = nn.Conv2d(256, self.num_classes, kernel_size=1, stride=1)

    def forward(self, x, low_level_features):


        #1x1 convolution of low-level features
        low_level_features = self.conv1(low_level_features)
        low_level_features = self.bn1(low_level_features)
        low_level_features = self.relu1(low_level_features)

        #concatenation
        #resize low_level_features to x size
        x = F.interpolate(x, size=low_level_features.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x, low_level_features), dim=1)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.drop2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.drop3(x)
        x = self.conv4(x)

        return x

DeepLabV3+:

class DeepLabV3_plus(nn.Module):
    def __init__(self, num_classes=1, shape=(512, 512),output_stride=16):
        super(DeepLabV3_plus, self).__init__()
        self.num_classes = num_classes
        self.output_stride = output_stride
        self.shape = shape

        if self.output_stride == 16:
            oslist = [6,12,18]
        elif self.output_stride == 8:
            oslist = [12,24,36]


        #backbone
        self.backbone = Xception(os=self.output_stride)

        #ASPP
        self.aspp = ASPP(2048, oslist)

        #decoder
        self.decoder = Decoder(728, num_classes=self.num_classes)

    def forward(self, x):
        x, low_level_features = self.backbone(x)

        #ASPP
        x = self.aspp(x)

        #decoder
        x = self.decoder(x, low_level_features)

        #need resize to 512, 512
        x = F.interpolate(x, size=self.shape, mode='bilinear', align_corners=False)

        return x

train:

def train(model, train_loader, valid_loader, epoch, val_term):
    model = model.to(DEVICE[0])

    model.train()

    class_weight = []
    for i in range(n_classes):
        class_weight.append(1) if i != 0 else class_weight.append(0)

    for i in range(epoch):
      ious = 0
      epoch_loss = 0
      if not model.training:
        model.train()
      for batch_idx, (data, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = model(data.float())
        target = target.permute(0,3,1,2)

        loss_output = loss(output.float(), target.float())
        iou_score = iou(output.int(), target.int(), n_classes)
        ious += iou_score
        iter_loss = loss_output.item()
        loss_output.backward()
        epoch_loss += iter_loss
        if batch_idx%(len(train_loader)//4) == 0:
            wandb.log({"loss": iter_loss})
            for j in range(n_classes):
              wandb.log({f"iou_class{j}": iou_score[j]})
            wandb.log({"miou": torch.nanmean(iou_score)})

      epoch_ious = ious/len(train_loader)
      epoch_miou = torch.nanmean(epoch_ious)
      epoch_loss = epoch_loss/len(train_loader)

      print(f"Epoch {i}/{epoch}\nLoss: {epoch_loss:.6f}\nclass IoU: {epoch_ious.numpy()}\nmIoU: {epoch_miou.numpy():.6f}")
      if i % val_term == val_term-1 or i==0:
          val_miou = validation(model, valid_loader)
          torch.save(model.state_dict(), 'epoch_{}.pth'.format(i))
      optimizer.step()
      scheduler.step(val_miou)
      
0

There are 0 best solutions below