Pytorch reserving way more data than needed


I'm trying to fine-tune a SentenceTransformer, but I'm running into an OOM error (I'm training the model on Google Cloud).

I keep getting an error saying that PyTorch reserves ~13GB (out of the ~14GB available), so there is no room left for any batch.

If I calculate the actual memory used by the model, it's only around 1.3GB:

import torch
from sentence_transformers import SentenceTransformer, models

model_name = "alexandrainst/scandi-nli-large"
word_embedding_model = models.Transformer(model_name, max_seq_length=512)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])
model.to("cuda")

# sum the memory taken by parameters and buffers
param_size = 0
for param in model.parameters():
    param_size += param.nelement() * param.element_size()
buffer_size = 0
for buffer in model.buffers():
    buffer_size += buffer.nelement() * buffer.element_size()
size_all_mb = (param_size + buffer_size) / 1024 ** 2
print('model size: {:.3f}MB'.format(size_all_mb))   # ~1300
print(torch.cuda.memory_reserved() / (1024 ** 2))   # ~1300

I have tried calling torch.cuda.empty_cache() and setting os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128", but the same error occurs.
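For reference, this is roughly what I tried (as far as I understand, the allocator config has to be set before the first CUDA allocation):

import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"  # before any CUDA allocation

import torch
torch.cuda.empty_cache()

# compare what is actually allocated vs. what the caching allocator has reserved
print(torch.cuda.memory_allocated() / 1024**2, "MB allocated")
print(torch.cuda.memory_reserved() / 1024**2, "MB reserved")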

Isn't there a way to stop PyTorch from reserving so much memory (or at least reduce it) and only use the memory that is actually needed?

1 Answer

Karl

You are calculating the memory required for just the weights of the model - this is a fraction of the total memory required for training.

When you train the model, you also have memory allocation for the model's activations, gradients, and optimizer state.
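As a rough back-of-envelope estimate (assuming fp32 weights and a standard Adam/AdamW optimizer, which keeps two extra states per parameter):

weights_gb = 1.3                 # fp32 parameters (your measurement)
gradients_gb = weights_gb        # one fp32 gradient per parameter
optimizer_gb = 2 * weights_gb    # Adam/AdamW: exp_avg + exp_avg_sq per parameter

print(weights_gb + gradients_gb + optimizer_gb)  # ~5.2 GB before any activations

Activations for a batch of 512-token sequences going through a large model can add several more GB on top of that, which is how you end up at 13-14GB.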

PyTorch doesn't "reserve more memory than needed" - you simply need that much memory for what you are trying to do.

To reduce the memory required for fine-tuning, you can look into the following (a rough sketch combining a few of these follows the list):

  • using mixed precision training
  • reducing the batch size and using gradient accumulation
  • enabling gradient checkpointing
  • fine-tuning only the final layers of the model
  • using parameter-efficient fine-tuning methods like LoRA
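A minimal sketch of how a few of these could be combined for the setup in your question, assuming the classic sentence-transformers fit() training API; the placeholder data and the commented-out layer-freezing attribute path are assumptions you would adapt to your own data and architecture:

from torch.utils.data import DataLoader
from sentence_transformers import SentenceTransformer, models, losses, InputExample

model_name = "alexandrainst/scandi-nli-large"
word_embedding_model = models.Transformer(model_name, max_seq_length=512)
pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension())
model = SentenceTransformer(modules=[word_embedding_model, pooling_model])

# gradient checkpointing on the underlying Hugging Face model:
# trades extra compute for much smaller activation memory
word_embedding_model.auto_model.gradient_checkpointing_enable()

# optionally freeze everything except the last encoder layer to shrink
# gradient and optimizer-state memory; the exact attribute path
# (encoder.layer[-1]) depends on the architecture, so inspect auto_model first
# for param in word_embedding_model.auto_model.parameters():
#     param.requires_grad = False
# for param in word_embedding_model.auto_model.encoder.layer[-1].parameters():
#     param.requires_grad = True

train_examples = [InputExample(texts=["sentence a", "sentence b"], label=1.0)]  # placeholder data
train_dataloader = DataLoader(train_examples, shuffle=True, batch_size=4)       # small batch size
train_loss = losses.CosineSimilarityLoss(model)

model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=1,
    use_amp=True,  # mixed precision training
)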

Since you are using Hugging Face, this will be useful.