What should a collator do, exactly?

Suppose we have an audio classification task (AudioMNIST).

My pipeline, and other pipelines I've seen, consist of the following steps:

  1. Read the dataset (the data samples).
  2. Apply the base transforms (merge the audio channels, change the bitrate, etc.).
  3. Split the dataset into train, test, etc. subsets.
  4. Apply the main transforms (different for train and test), such as augmentation.
  5. Batch the samples (along with the sampling).
  6. Pad/Truncate the batch samples.
  7. Do the forward pass with the batch.
  8. <…>

I've seen this scheme:

  • Dataset or a subclass - steps 1, 2, 3, 4
  • Collator - step 6

Either:

  • Dataset or a subclass - step 1
  • somebody else - steps 2, 3, 4
  • Collator - step 6

Or:

  • Dataset or a subclass - step 1
  • somebody else - step 3
  • Collator - steps 2, 4, 6

What should the collator do, and what shouldn't it do? (That's the main question.) Which scheme is correct?

Answer by Karl:

You've tagged this with pytorch, so I'll give the PyTorch answer.

PyTorch's data utilities (torch.utils.data) provide a Dataset and a DataLoader. tl;dr: the Dataset handles loading a single example, while the DataLoader handles batching and any bulk processing.

A Dataset subclass implements two methods: __len__, which returns the number of items in the dataset, and __getitem__, which loads a single item.

from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self):
        ...  # gather file paths, labels, etc.

    def __len__(self):
        ...  # return the number of items in the dataset

    def __getitem__(self, index):
        ...  # load and return a single example
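
For the audio task in the question, a concrete Dataset might look like the sketch below. It assumes an AudioMNIST-style layout (one .wav file per example, with the digit label at the start of the filename) and uses torchaudio for loading; every name and convention here is illustrative, not prescribed.

import os
import torchaudio
from torch.utils.data import Dataset

class AudioMNISTDataset(Dataset):
    """Hypothetical AudioMNIST-style dataset: one .wav per example,
    with the digit label as the first character of the filename."""

    def __init__(self, root):
        self.paths = sorted(
            os.path.join(root, f) for f in os.listdir(root) if f.endswith(".wav")
        )

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index):
        path = self.paths[index]
        waveform, sample_rate = torchaudio.load(path)  # (channels, time)
        waveform = waveform.mean(dim=0)                # merge channels to mono (step 2)
        label = int(os.path.basename(path)[0])         # assumed filename convention
        return waveform, label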

The DataLoader gathers a list of outputs from the Dataset (i.e. batch_input = [dataset.__getitem__(i) for i in idxs]) and sends that list to its collate_fn.

from torch.utils.data import DataLoader

def my_collate_fn(batch):
    ...  # batch is a list of __getitem__ outputs; return model-ready tensors

dataloader = DataLoader(my_dataset, batch_size=batch_size, collate_fn=my_collate_fn)

In terms of what goes where: the Dataset should handle loading single examples. Its __getitem__ is called in parallel across worker processes (when num_workers > 0), so CPU-bound per-example work belongs in the Dataset. Loading from disk (if applicable) is also typically done there.

The collate_fn converts a list of outputs from your Dataset into whatever format your model expects. Since it sees the whole batch at once, it is often the more efficient place for batch-level processing: stacking tensors, padding to a common length, generating masks, and other bulk tensor ops all work well in the collate_fn.
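
As a sketch of those bulk ops, here is one way such a collate_fn could pad variable-length mono waveforms and build a padding mask. It assumes the (waveform, label) item format from the hypothetical Dataset above; the mask convention (True on real samples) is a choice, not a requirement.

import torch
from torch.nn.utils.rnn import pad_sequence

def my_collate_fn(batch):
    # batch is a list of (waveform, label) pairs, each waveform of shape (time,)
    waveforms = [wav for wav, _ in batch]
    labels = torch.tensor([label for _, label in batch])

    lengths = torch.tensor([wav.shape[0] for wav in waveforms])
    padded = pad_sequence(waveforms, batch_first=True)  # (batch, max_time)

    # True where real audio samples live, False over the padding
    mask = torch.arange(padded.shape[1])[None, :] < lengths[:, None]

    return padded, mask, labels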

In general, think of the Dataset as doing multi-process work on single examples, while the DataLoader runs single-process on a batch of examples.
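
Putting it together with the hypothetical pieces above (batch size, worker count, and path are arbitrary):

from torch.utils.data import DataLoader

dataset = AudioMNISTDataset(root="AudioMNIST/data")  # assumed data directory
dataloader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,             # __getitem__ runs in 4 worker processes
    collate_fn=my_collate_fn,  # assembles each list of examples into a batch
)

for padded, mask, labels in dataloader:
    ...  # forward pass with the batch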