How can I make an interpolation function that works with torch gradients?

Is there a way to build a down/upscaling function that lets gradients flow through it? I need this so that backpropagation can jointly train a downscaling factor generator and Faster RCNN.

I'm designing a computer vision deep learning pipeline with PyTorch.

Preliminary:

It consists of two parts: preprocessing and the vision task. For the preprocessing part, I created a downscaling factor generator model and an image down/upscaling module. In the vision task part, I feed the preprocessed image forward to perform object detection.

The training flow looks like this: img input -> downscaling factor generation (img forward) -> img downscaling with the downscaling factor -> img upscaling with 1/downscaling factor -> vision task (Faster RCNN, upscaled img) -> step backward.

Question:

While designing this flow, I got stuck at the down/upscaling module. I tried PyTorch's F.interpolate(img, scale_factor, ...) function, but it blocked the flow of gradients. I pass the original image and the downscaling factor (produced by the model, with gradient) into F.interpolate, but grad_fn disappears from the output.

I tried to write a custom interpolation function like this:

def bilinear_interpolate(self, img, scale_factor):
    print('img, scale_factor :',img,scale_factor)
    n, c, h, w = img.size()
    new_h, new_w = int(h * scale_factor), int(w * scale_factor) 
    device = img.device

    h_scale = torch.linspace(0, h-1, new_h, device=device)
    w_scale = torch.linspace(0, w-1, new_w, device=device)

    grid_h, grid_w = torch.meshgrid(h_scale, w_scale)

    h_floor = grid_h.floor().long()
    h_ceil = h_floor + 1
    h_ceil = h_ceil.clamp(max=h-1)

    w_floor = grid_w.floor().long()
    w_ceil = w_floor + 1
    w_ceil = w_ceil.clamp(max=w-1)
    print('h_floor, h_ceil, w_floor, w_ceil :',h_floor, h_ceil, w_floor, w_ceil)

    tl = img[:, :, h_floor, w_floor]
    tr = img[:, :, h_floor, w_ceil]
    bl = img[:, :, h_ceil, w_floor]
    br = img[:, :, h_ceil, w_ceil]

    h_frac = grid_h - h_floor.to(device)
    w_frac = grid_w - w_floor.to(device)

    # bilinear interpolation
    top = tl + (tr - tl) * w_frac
    bottom = bl + (br - bl) * w_frac
    interpolated_img = top + (bottom - top) * h_frac

    return interpolated_img

but it doesn't work, because of the int/float conversions and some of the variable assignments.
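
For context, here is a minimal sketch of one likely cause, under the assumption that the image itself does not require gradients and only the scale factor does. The name `scale` is a hypothetical stand-in for the generator output. Converting the scaled size to a Python integer cuts the link to `scale`, so nothing that requires a gradient is left among the inputs to F.interpolate:

```python
import torch
import torch.nn.functional as F

# Plain input image: it does not require gradients.
img = torch.rand(1, 3, 8, 8)

# Hypothetical stand-in for the factor produced by the generator model.
scale = torch.tensor(0.5, requires_grad=True)

# .item() / int() return plain Python numbers, so autograd can no longer
# trace the output size back to `scale`.
new_h = int((img.shape[-2] * scale).item())
new_w = int((img.shape[-1] * scale).item())

out = F.interpolate(img, size=(new_h, new_w), mode="bilinear", align_corners=False)

print(out.shape)          # torch.Size([1, 3, 4, 4])
print(out.requires_grad)  # False
print(out.grad_fn)        # None
```

The output size of a resize is an integer by nature, so this step has no useful derivative with respect to the scale factor.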

Answer by Muhammed Yunus:

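The BilinearInterpolation layer below performs scaling while preserving gradient flow to its input tensor. It wraps F.interpolate, and the gradient function at the output is <UpsampleBilinear2DBackward0>. Note that scale_factor is converted to an integer output size, so no gradient flows back to the scale factor itself.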

Output:

"grad_fn" of z_scaled is: <UpsampleBilinear2DBackward0 object...>

Interpolation layer:

class BilinearInterpolation(nn.Module):
    def __init__(self, scale_factor):
        super(BilinearInterpolation, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        new_height = int(height * self.scale_factor)
        new_width = int(width * self.scale_factor)

        # Perform bilinear interpolation
        interpolated = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True)
        return interpolated
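
A quick way to confirm that gradients reach the input is to call backward on the interpolated output and inspect the input's .grad. The sketch below calls F.interpolate directly with the same mode and align_corners setting as the layer; the tensor sizes are arbitrary values chosen for illustration:

```python
import torch
import torch.nn.functional as F

# Input that requires gradients, e.g. a feature map or a learnable image.
x = torch.rand(1, 1, 4, 4, requires_grad=True)

# Same call pattern as in the layer: explicit integer output size.
y = F.interpolate(x, size=(8, 8), mode="bilinear", align_corners=True)

# Backpropagate a scalar so that x.grad gets populated.
y.sum().backward()

print(y.grad_fn)     # an UpsampleBilinear2DBackward node
print(x.grad.shape)  # torch.Size([1, 1, 4, 4])
```

This only shows that the gradient reaches x. The output size is a plain integer, so the scale factor receives no gradient.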

Reproducible example:

import torch
from torch import nn
import torch.nn.functional as F

class BilinearInterpolation(nn.Module):
    def __init__(self, scale_factor):
        super(BilinearInterpolation, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        batch_size, channels, height, width = x.size()
        new_height = int(height * self.scale_factor)
        new_width = int(width * self.scale_factor)

        # Perform bilinear interpolation
        interpolated = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True)
        return interpolated

#
# Test data
#
import numpy as np
xx, yy = np.meshgrid(*[np.linspace(-1, 1)] * 2)
z = np.sin(xx)**2 + np.cos(yy)**2
z = torch.tensor(z).float()
z = z[None, None, ...]
z.requires_grad = True

#View original data
import matplotlib.pyplot as plt
plt.contourf(z[0, 0, :, :].detach(), cmap='YlGnBu')
plt.text(x=9, y=42, s=f'original tensor\n{list(z.shape)}', fontweight='bold')

#Scale
scale_factor = 2
z_scaled = BilinearInterpolation(scale_factor=scale_factor)(z)

#View scaled data
plt.contourf(z_scaled[0, 0, ...].detach(), cmap='YlGnBu', zorder=0)
plt.text(x=28, y=91, s=f'{scale_factor}x interpolated tensor\ndims={list(z_scaled.shape)}', fontweight='bold')
plt.xlabel('x')
plt.ylabel('y')
plt.gcf().set_size_inches(5, 5)

print('"grad_fn" of z_scaled is:', z_scaled.grad_fn)
plt.show()  # needed to display the figure when run as a script
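
If the gradient also has to reach the scale factor, as in the question, one possible direction is to swap in a different technique: keep the output size fixed and let the scale factor control the sampling grid through F.affine_grid and F.grid_sample. The sketch below is not equivalent to resizing with F.interpolate: it zooms the image content inside a fixed-size output instead of changing the resolution. The function name `differentiable_zoom` and the test values are hypothetical.

```python
import torch
import torch.nn.functional as F


def differentiable_zoom(img, scale):
    """Resample `img` on a grid scaled by `scale` (a 0-dim tensor).

    The output keeps the input's size; scale < 1 zooms in on the centre,
    scale > 1 zooms out (borders are filled by edge replication).
    """
    n, c, h, w = img.shape
    zero = torch.zeros_like(scale)
    # 2x3 affine matrix [[s, 0, 0], [0, s, 0]] built with stack so that
    # autograd keeps the dependency on `scale`.
    theta = torch.stack([
        torch.stack([scale, zero, zero]),
        torch.stack([zero, scale, zero]),
    ])
    theta = theta.unsqueeze(0).repeat(n, 1, 1).to(img.dtype)
    # Sampling locations in normalized [-1, 1] coordinates.
    grid = F.affine_grid(theta, torch.Size((n, c, h, w)), align_corners=False)
    return F.grid_sample(img, grid, mode="bilinear",
                         padding_mode="border", align_corners=False)


torch.manual_seed(0)
img = torch.rand(2, 3, 16, 16)
scale = torch.tensor(0.8, requires_grad=True)  # stand-in for the generator output

out = differentiable_zoom(img, scale)
out.mean().backward()

print(out.shape)   # torch.Size([2, 3, 16, 16])
print(scale.grad)  # a scalar gradient with respect to the scale factor
```

Whether a fixed-size zoom is an acceptable replacement for a true downscale/upscale round trip depends on the task, so treat this as a starting point only.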