Is it possible to add a trainable filter after an autoencoder?


So I'm building a denoiser with an autoencoder. The idea is that, before computing my loss, I apply an empirical Wiener filter to a texture map of the image (the residual between the input and the autoencoder output) and add it back to the output, restoring 'lost detail'. I've implemented this filter in PyTorch.

My first attempt was to call the filter at the end of my autoencoder's forward function. I can train this network, and it backpropagates through my filter during training. However, if I print my network, the filter is not listed, and torchsummary doesn't include it when counting parameters.

This has me thinking that I am only training the autoencoder, and that my filter applies the same fixed filtering every time rather than learning.
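For example, printing the model only shows the autoencoder's blocks; the wiener_3d call inside forward is invisible to PyTorch's module registry:

model = AutoEncoder()
print(model)                                       # lists block1..block6 only
print(sum(p.numel() for p in model.parameters()))  # count excludes the filter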

Is what I’m trying to do possible?

Below is my Autoencoder:

import torch
import torch.nn as nn

class AutoEncoder(nn.Module):
    """Simple autoencoder implementation."""

    def __init__(self):
        super(AutoEncoder, self).__init__()
        # Encoder
        self.block1 = nn.Sequential(
            nn.Conv2d(1, 48, 3, padding=1),
            nn.Conv2d(48, 48, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block2 = nn.Sequential(
            nn.Conv2d(48, 48, 3, padding=1),
            nn.MaxPool2d(2),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        # Decoder
        self.block3 = nn.Sequential(
            nn.Conv2d(48, 48, 3, padding=1),
            nn.ConvTranspose2d(48, 48, 2, 2, output_padding=1),
            nn.BatchNorm2d(48),
            nn.LeakyReLU(0.1)
        )
        self.block4 = nn.Sequential(
            nn.Conv2d(96, 96, 3, padding=1),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.ConvTranspose2d(96, 96, 2, 2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        self.block5 = nn.Sequential(
            nn.Conv2d(144, 96, 3, padding=1),
            nn.Conv2d(96, 96, 3, padding=1),
            nn.ConvTranspose2d(96, 96, 2, 2),
            nn.BatchNorm2d(96),
            nn.LeakyReLU(0.1)
        )
        self.block6 = nn.Sequential(
            nn.Conv2d(97, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.Conv2d(64, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.Conv2d(32, 1, 3, padding=1),
            nn.LeakyReLU(0.1)
        )

    def forward(self, x):
        # Encoder: block2 is reused, so the four downsampling stages share weights
        pool1 = self.block1(x)
        pool2 = self.block2(pool1)
        pool3 = self.block2(pool2)
        pool4 = self.block2(pool3)
        pool5 = self.block2(pool4)
        # Decoder with skip connections (block5 is likewise reused)
        upsample5 = self.block3(pool5)
        concat5 = torch.cat((upsample5, pool4), 1)
        upsample4 = self.block4(concat5)
        concat4 = torch.cat((upsample4, pool3), 1)
        upsample3 = self.block5(concat4)
        concat3 = torch.cat((upsample3, pool2), 1)
        upsample2 = self.block5(concat3)
        concat2 = torch.cat((upsample2, pool1), 1)
        upsample1 = self.block5(concat2)
        concat1 = torch.cat((upsample1, x), 1)
        output = self.block6(concat1)

        # Texture map: residual between the input and the network output
        t_map = x - output

        # Apply the Wiener filter to each image in the batch separately
        # (could instead handle the batch dimension inside wiener_3d)
        for i in range(t_map.shape[0]):
            tensor = t_map[i, :, :, :]
            tensor = torch.squeeze(tensor)          # squeeze to the 2-D format wiener_3d expects
            tensor = wiener_3d(tensor, 0.05, 10)    # apply Wiener with fixed std and block size
            tensor = torch.unsqueeze(tensor, 0)     # restore the channel dimension
            t_map[i, :, :, :] = tensor              # write the filtered image back into the batch

        filtered_output = output + t_map
        return filtered_output

The for loop at the end applies the filter to each image in the batch. I realise this isn't parallelisable, so if anyone has ideas there, I'd appreciate it. I can post the wiener_3d() filter function if that helps; I've left it out to keep the post short.
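For what it's worth, a sketch of the same loop without the in-place writes into t_map (which autograd can object to) would be:

filtered = [wiener_3d(img.squeeze(0), 0.05, 10).unsqueeze(0) for img in t_map]
t_map = torch.stack(filtered, dim=0)    # back to (N, 1, H, W)

It's still a Python loop, but it works for any batch size and builds a fresh tensor instead of mutating one.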

I’ve tried to define a custom layer class with the filter inside it but I got lost very quickly.

Any help would be greatly appreciated!

1 Answer

Accepted answer, by Jan:

If all you want is to turn your Wiener filter into a module, the following would do:

import torch

class WienerFilter(torch.nn.Module):
    def __init__(self, param_a=0.05, param_b=10):
        super(WienerFilter, self).__init__()
        # Registering param_a makes it a trainable parameter: it shows up in
        # parameters(), receives gradients, and can be accessed as self.param_a.
        self.register_parameter("param_a", torch.nn.Parameter(torch.tensor(param_a)))
        self.param_b = param_b

    def forward(self, input):
        # Same per-image loop as before, but now using the trainable param_a
        for i in range(input.shape[0]):
            tensor = input[i]
            tensor = torch.squeeze(tensor)
            tensor = wiener_3d(tensor, self.param_a, self.param_b)
            tensor = torch.unsqueeze(tensor, 0)
            input[i] = tensor
        return input

You can apply this by adding the line

self.wiener_filter = WienerFilter()

to the __init__ function of your AutoEncoder.

In forward you then call it by replacing the for loop with

filtered_output = output + self.wiener_filter(t_map)
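Putting both changes together, the end of the AutoEncoder becomes:

# in AutoEncoder.__init__:
self.wiener_filter = WienerFilter()

# at the end of AutoEncoder.forward, replacing the for loop:
t_map = x - output
filtered_output = output + self.wiener_filter(t_map)
return filtered_output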

Torch knows that wiener_filter is a member module, so it will list it when you print your AutoEncoder's modules.
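You can check this directly:

model = AutoEncoder()
print(model)    # WienerFilter now appears as a submodule
print([name for name, _ in model.named_parameters() if "wiener" in name])
# ['wiener_filter.param_a']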

If you want to parallelise your Wiener filter, you need to express it in PyTorch's terms, meaning with its tensor operations, which are already implemented in a parallel fashion.
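As an illustration only (this is not your wiener_3d, just a Wiener-style gain applied in the Fourier domain), note how torch.fft operates on the trailing two dimensions, so an entire (N, 1, H, W) batch is filtered in one call with no Python loop:

def wiener_like_batched(t_map, noise_var=0.05 ** 2):
    spectrum = torch.fft.fft2(t_map)              # FFT over the last two dims, batched
    power = spectrum.real ** 2 + spectrum.imag ** 2
    gain = power / (power + noise_var)            # classic Wiener gain per frequency
    return torch.fft.ifft2(gain * spectrum).real  # filtered batch, same shape

If your block-based filter can be rewritten with operations like these (unfold, FFT, element-wise arithmetic), the batch loop disappears.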