How to train a large HuggingFace Transformers model on a v3-8 TPU VM?

76 Views Asked by At

I'm trying to fine-tune OpenLlama 3Bv2 for SequenceClassification but I've got very little experience working with TPUs. Nothing doesn't seem to be performing correctly. The first 2 batches run perfectly fine but I then receive this error: BrokenProcessPool: A process in the process pool was terminated abruptly while the future was running or pending.

I'm also generally confused about how I should go about training on the TPU and how I'd use mixed precision. In a perfect world, I'd be using something like Jax or even TensorFlow however I'm not too sure how to use those with a HF Transformers model.

Here's my current code:

import torch
import os
import pickle
from torch.utils.data import DataLoader, Dataset
from transformers import LlamaForSequenceClassification, get_linear_schedule_with_warmup
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.xla_multiprocessing as xmp

# Paths and model selection
path = "trained_paragraph_1/"
mini = False
pref = "tokenized_llama/mini/" if mini else "tokenized_llama/"
model_path = 'tokenized_llama/model'

# Training parameters
batch_size = 32
val_batch_size = 64
epochs = 1
steps_per_eval = 10000

# Optimization settings
learning_rate = 1e-6
weight_decay = 0.05
warmup_ratio = 0.1

class PreTokenizedTextDataset(Dataset):
    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx], dtype=torch.long) for key, val in self.encodings.items()}
        item['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return item
    
# Define the map function for multiprocessing
def _mp_fn(index, flags):
    # Load the pre-tokenized training and validation datasets
    with open(f'{pref}train_tokenized.pkl', 'rb') as f:
        train_encodings, train_labels = pickle.load(f)
    train_dataset = PreTokenizedTextDataset(train_encodings, train_labels)

    with open(f'{pref}val_tokenized.pkl', 'rb') as f:
        val_encodings, val_labels = pickle.load(f)
    val_dataset = PreTokenizedTextDataset(val_encodings, val_labels)

    # Initialize the TPU device
    device = xm.xla_device()

    # Define the DataLoader for training and validation
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=True
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler, pin_memory=True)

    val_sampler = torch.utils.data.distributed.DistributedSampler(
        val_dataset,
        num_replicas=xm.xrt_world_size(),
        rank=xm.get_ordinal(),
        shuffle=False
    )
    val_loader = DataLoader(val_dataset, batch_size=val_batch_size, sampler=val_sampler, pin_memory=True)

    # Initialize the tokenizer and model
    model = LlamaForSequenceClassification.from_pretrained(model_path, num_labels=2)

    # Move the model to the TPU
    model.to(device)
    
    # Define the optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

    # Define the learning rate scheduler with warmup
    num_training_steps = len(train_loader) * epochs
    num_warmup_steps = int(warmup_ratio * num_training_steps)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps)

    # Training loop with checkpointing and evaluation
    model.train()
    global_step = 0
    best_accuracy = 0.0
    for epoch in range(epochs):
        para_loader = pl.ParallelLoader(train_loader, [device])
        train_iterator = para_loader.per_device_loader(device)
        if xm.is_master_ordinal():
            train_iterator = tqdm(train_iterator, desc=f"Epoch {epoch+1}", unit="batch")
        for batch in train_iterator:
            optimizer.zero_grad()

            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
            loss = outputs.loss

            loss.backward()
            xm.optimizer_step(optimizer)

            scheduler.step()

            global_step += 1
            if xm.is_master_ordinal():
                train_iterator.set_postfix(loss=loss.item())

            if global_step % steps_per_eval == 0 and xm.is_master_ordinal():
                # Checkpointing
                xm.save(model.state_dict(), os.path.join(path, f'checkpoint-{global_step}.pt'))
                # Evaluation
                model.eval()
                total_eval_accuracy = 0
                total_eval_loss = 0
                para_loader = pl.ParallelLoader(val_loader, [device])
                eval_iterator = para_loader.per_device_loader(device)
                for batch in eval_iterator:
                    with torch.no_grad():
                        input_ids = batch['input_ids'].to(device)
                        attention_mask = batch['attention_mask'].to(device)
                        labels = batch['labels'].to(device)
                        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
                        logits = outputs.logits
                        loss = outputs.loss
                        total_eval_loss += loss.item()
                        predictions = torch.argmax(logits, dim=-1)
                        total_eval_accuracy += accuracy_score(labels.cpu().numpy(), predictions.cpu().numpy())

                avg_val_accuracy = total_eval_accuracy / len(val_loader)
                avg_val_loss = total_eval_loss / len(val_loader)
                xm.master_print(f"Step {global_step}, Validation Loss: {avg_val_loss}, Validation Accuracy: {avg_val_accuracy}")

                # Save the best model
                if avg_val_accuracy > best_accuracy:
                    best_accuracy = avg_val_accuracy
                    xm.save(model.state_dict(), os.path.join(path, 'best_model.pt'))

                model.train()

    # Load the best model
    if xm.is_master_ordinal():
        model.load_state_dict(torch.load(os.path.join(path, 'best_model.pt')))

    # Save the fine-tuned model
    if xm.is_master_ordinal():
        model.save_pretrained(f'{path}fine_tuned_model')

# Start training using xmp.spawn
FLAGS = {}
xmp.spawn(_mp_fn, args=(FLAGS,), start_method='fork')

If anyone has any idea on how I can fix the problem I'm encountering and/or how to improve performance that would be great. Genuienly any tips or tricks would be fantastic. Thanks!

I've tried to implement BF16 but I couldn't get it to work.

I also tried to switch to using TensorFlow but I couldn't get the OpenLlama model to work there.

0

There are 0 best solutions below