How to normalize the ImageNet dataset with PyTorch?


I am trying to write a simple script to train resnet50 on the ImageNet dataset, and I don't quite get why normalization does not work. When I use this transformation as data augmentation:

train_transforms = transforms.Compose([
    transforms.Resize((224, 224), antialias=True),
    transforms.RandomCrop(180),
    transforms.Resize((224, 224), antialias=True),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.15),
    transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5),
    transforms.ToTensor()
])

and then I check the results, they seem to be OK. It is a light augmentation. For example, this waffle iron:

Original waffle iron

looks like this:

(image: lightly augmented waffle iron)

However if I do the exact same thing, but add normalization, like this:

train_transforms = transforms.Compose([
        transforms.Resize((224, 224), antialias=True),
        transforms.RandomCrop(180),
        transforms.Resize((224, 224), antialias=True),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.15),
        transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])

The waffle iron looks like this:

(image: waffle iron after normalization, colors visibly distorted)

I tried putting the normalization in different places within the transforms.Compose, but it did not work there either. In some arrangements it throws an error (since Normalize is defined on tensors, so it must come after ToTensor), or it just ruins the picture in another way, like:

(image: waffle iron distorted differently)

This happens when I try this order:

train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
    transforms.Resize((224, 224), antialias=True),
    transforms.RandomCrop(180),
    transforms.Resize((224, 224), antialias=True),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.15),
    transforms.RandomApply([transforms.GaussianBlur(3, sigma=(0.1, 2.0))], p=0.5),
])

How could I make it work, and what is the problem here? The rest of the code is just a standard PyTorch training loop, nothing fancy.

Update:

To answer some questions in the comments, I used these values as mean and std:

mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

I found in several articles and blog posts that these are the usual ImageNet values. The relevant part of the plotting code:

import os
import numpy as np
from PIL import Image
from torchvision.transforms import ToPILImage

# Helper to convert a tensor to a PIL Image
to_pil = ToPILImage()

idx = indices[i].item()
label = labels[i]

# Define the paths to save original and augmented images
original_img_path = os.path.join(debug_dir, f"{global_id_counter}_{interpretable_label[i]}_original.png")
augmented_img_path = os.path.join(debug_dir, f"{global_id_counter}_{interpretable_label[i]}_augmented.png")

# Convert tensor to PIL Image and save
original_pil_img = to_pil(original_img[i].cpu())  # Convert tensor to PIL Image
original_pil_img.save(original_img_path)

# Get the augmented image from inputs
augmented_img_np = inputs_np[i].transpose(1, 2, 0)

# Convert from float to uint8 (note: values outside [0, 1] are not clipped here)
augmented_img_np = (augmented_img_np * 255).astype(np.uint8)

# Convert to PIL image and save
augmented_img = Image.fromarray(augmented_img_np)
augmented_img.save(augmented_img_path)

I save the images for later comparison.

2 Answers

Answer by Yakov Dan:

Assuming you've used reasonable values for mean and std, there's nothing wrong with your code. Normalization does not preserve the visual appearance of images. Typically you can expect changes in color.
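
A plain-Python sketch of the range a pixel in [0, 1] can reach after this normalization, which is why a directly rendered normalized image no longer looks like the original:

```python
# Standard ImageNet statistics (assumed, from the question)
mean = [0.485, 0.456, 0.406]
std = [0.229, 0.224, 0.225]

# Normalize maps x to (x - mean) / std, so the [0, 1] input range
# stretches to roughly [-2.1, 2.6] depending on the channel.
for c, (m, s) in enumerate(zip(mean, std)):
    lo = (0.0 - m) / s
    hi = (1.0 - m) / s
    print(f"channel {c}: [{lo:.2f}, {hi:.2f}]")
```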

Answer by Karl:

Image pixel values are expected to be floats between 0 and 1 or integers between 0 and 255. When you normalize an image, you get floats outside that range. When you view the normalized image with PIL, the values get clipped to the expected range, causing the visual issues.

If you want to visualize the transformed image, you need to denormalize it with the same mean and std values to get the pixel values back into the expected range.