Changing the Attention Layer of a Transformer

I would like to test my own formulation of the attention mechanism for a transformer. To that end, I'm looking for an existing pre-trained transformer whose code is easy to read through and that uses a reasonably small dataset. I just want to take that model and replace its attention code with my own.

I have already tried using https://github.com/hkproj/pytorch-transformer, but for some reason my code throws an error after the first epoch. Each epoch takes about 30 minutes, so it's hard to debug and find where I've gone wrong.

I've also looked at the Hugging Face GitHub repository, but there are so many options that I got lost. I'm very new to deep learning, and this is my first time coding a model.

Thanks

Answer by nairbv

If you only want to test a change to the attention mechanism, you might want something simpler than a full encoder-decoder transformer.

There are three main kinds of transformer models: encoder-decoder, encoder-only, and decoder-only.

The original transformer (implemented at the link you shared) is a full encoder-decoder. In the original Vaswani et al. paper, it was trained on a translation task. That makes sense: the encoder encodes the information in the input sentence, and the decoder then uses it to generate a sentence in the target language.

An example of an encoder-only model is BERT, typically trained on a masked language modeling (MLM) task.

What we generally refer to as "decoder-only" models are GPT-style models. IMO these are the simplest to train, since they're typically trained to predict the next token on arbitrary text. Some simple implementations of decoder-only models include:

A common pattern is to use a trivial tokenizer that treats each character as a token, use a simple single-document training text like "the complete works of Shakespeare," and work on a small variant of the model with fewer layers/dimensions. I find I can typically train a minimal GPT-style model this way on a consumer-grade laptop (MacBook) and get it producing meaningful words within an hour or two.
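As a rough sketch of that setup (not taken from any particular repository), the snippet below builds a character-level tokenizer, forms next-token training pairs by shifting the sequence by one, and applies the causal mask that distinguishes decoder-only attention from encoder-only attention. The toy text, `block_size`, and the random stand-in scores are assumptions for illustration only.

```python
import torch

# Hypothetical toy corpus; in practice this would be a full text file (e.g. Shakespeare).
text = "to be or not to be"

# Character-level tokenizer: each distinct character gets an integer id.
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

def encode(s):
    return [stoi[ch] for ch in s]

def decode(ids):
    return "".join(itos[i] for i in ids)

data = torch.tensor(encode(text), dtype=torch.long)

# Next-token prediction: the target sequence is the input shifted left by one.
block_size = 8  # context length (arbitrary choice for this sketch)
x = data[:block_size]
y = data[1:block_size + 1]

# Causal mask: True marks "future" positions that token i is not allowed to see.
causal_mask = torch.triu(torch.ones(block_size, block_size, dtype=torch.bool), diagonal=1)

scores = torch.randn(block_size, block_size)  # stand-in for q @ k^T / sqrt(d)
weights = torch.softmax(scores.masked_fill(causal_mask, float("-inf")), dim=-1)
```

An encoder-only model such as BERT would simply skip the `masked_fill` step, so every token can attend to every other token.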

Answer by Muhammed Yunus

Since it's your first time coding up a model, I think it could be useful to work with a simplified attention model on a dataset like MNIST (handwritten digits 0-9). That gives you a manageable dataset to work with and results that are easy to visualise.

You could use the basic attention mechanism as a baseline, and compare results against your modifications.

The data and code below show how to implement a simple attention block, used for classifying digits. The complete code is at the end.

Load and visualise MNIST data.

The attention mechanism assumes you have broken each sample down into patches (tokens). The custom MakePatches layer takes in a 2D image and cuts it up into 2x2 patches.
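To see what the reshape/swap inside MakePatches does, here is a small NumPy sketch on a hypothetical 4x4 "image" whose pixel values are their flat indices. It mirrors the layer's `reshape` + `swapdims(3, 4)` + `reshape` sequence (NumPy's `swapaxes` standing in for PyTorch's `swapdims`), so each output row is one 2x2 patch read row by row.

```python
import numpy as np

# Hypothetical 4x4 image in B C H W layout; each pixel value is its flat index.
B, C, H, W = 1, 1, 4, 4
img = np.arange(H * W).reshape(B, C, H, W)

patch = 2
n_rows, n_cols = H // patch, W // patch

patches = (
    img
    .reshape(B, C, n_rows, patch, n_cols, patch)   # split H and W into (patch index, pixel in patch)
    .swapaxes(3, 4)                                 # -> B C patch_row patch_col row_px col_px
    .reshape(B, C, n_rows * n_cols, patch * patch)  # -> B C patch_idx patch_elements
)

print(patches[0, 0])
# [[ 0  1  4  5]
#  [ 2  3  6  7]
#  [ 8  9 12 13]
#  [10 11 14 15]]
```

The first patch is pixels (0, 1) from row 0 and (4, 5) from row 1, which confirms that the swap groups pixels by patch rather than by image row.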

Each patch comprises 4 pixels in this case. We can optionally project those 4 values into some other embedding size using the embedding layer EmbedPatches.
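EmbedPatches is just a learnable matrix multiplication, i.e. a linear layer without a bias. The sketch below checks that equivalence against `nn.Linear(bias=False)`; the random matrix `W` and input shapes are assumptions chosen to match the 2x2-patch, 5-dimensional setup used here.

```python
import torch
from torch import nn

token_size, embedding_size = 4, 5  # 2x2 patches -> 4 pixels, embedded into 5 dims

# Same operation as EmbedPatches: a (token_size x embedding_size) matrix applied with @.
W = torch.randn(token_size, embedding_size)
x = torch.randn(2, 1, 196, token_size)  # B C patch_idx patch_elements

embedded_matmul = x @ W

# nn.Linear stores its weight as (out_features, in_features), hence the transpose.
linear = nn.Linear(token_size, embedding_size, bias=False)
with torch.no_grad():
    linear.weight.copy_(W.T)

embedded_linear = linear(x)
```

Using `nn.Linear` instead would also give you a standard initialisation and an optional bias for free; the hand-rolled version just makes the operation explicit.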

The patches have a meaningful spatial arrangement, since they're derived from an image. The PositionalEncoding layer adds learnable information that helps index the patches.
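Why positional information is needed at all: self-attention on its own is permutation-equivariant, so shuffling the input tokens just shuffles the outputs in the same way, and the layer cannot tell where a patch came from. The NumPy sketch below demonstrates this with a single random head; the random `pos` offset is an assumption standing in for a learnt positional encoding.

```python
import numpy as np

rng = np.random.default_rng(0)

def self_attention(x, Wq, Wk, Wv):
    """Single-head scaled dot-product self-attention on x of shape (n_tokens, d)."""
    q, k, v = x @ Wq, x @ Wk, x @ Wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    scores -= scores.max(axis=-1, keepdims=True)  # numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v

n_tokens, d = 6, 4
x = rng.normal(size=(n_tokens, d))
Wq, Wk, Wv = (rng.normal(size=(d, d)) for _ in range(3))
perm = np.roll(np.arange(n_tokens), 1)  # a non-trivial reordering of the tokens

out = self_attention(x, Wq, Wk, Wv)
out_shuffled = self_attention(x[perm], Wq, Wk, Wv)
print(np.allclose(out_shuffled, out[perm]))  # True: position is invisible to attention

# Adding a position-dependent offset (stand-in for a learnt encoding) breaks the symmetry.
pos = rng.normal(size=(n_tokens, d))
out_pos = self_attention(x + pos, Wq, Wk, Wv)
out_pos_shuffled = self_attention(x[perm] + pos, Wq, Wk, Wv)
```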

The code includes an implementation of a dot-product self-attention block, SimpleAttention, which could be adapted for cross-attention with some modifications.
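The core computation is softmax(Q Kᵀ / √d) V. The sketch below writes it as a plain function, checks it against PyTorch's built-in `torch.nn.functional.scaled_dot_product_attention` (available in PyTorch 2.0 and later), and shows the cross-attention variant, where queries come from one sequence and keys/values from another. The projection matrices and the 10-token `context` sequence are random placeholders, not part of the original model.

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)

B, L, d_model, d_head = 2, 196, 5, 8

def attention(q, k, v):
    # softmax(q k^T / sqrt(d)) v: the same computation SimpleAttention performs per head
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    return torch.softmax(scores, dim=-1) @ v

Wq, Wk, Wv = (torch.randn(d_model, d_head) for _ in range(3))

# Self-attention: queries, keys and values all come from the same token sequence x.
x = torch.randn(B, L, d_model)
self_out = attention(x @ Wq, x @ Wk, x @ Wv)

# PyTorch's fused implementation (default scale 1/sqrt(d_head)) should agree numerically.
ref = F.scaled_dot_product_attention(x @ Wq, x @ Wk, x @ Wv)

# Cross-attention: queries from x, keys/values from another sequence (e.g. encoder
# outputs). The two sequences may have different lengths.
context = torch.randn(B, 10, d_model)
cross_out = attention(x @ Wq, context @ Wk, context @ Wv)
```

Only the source of the keys and values changes for cross-attention; the output always has one row per query token.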

We define the model by chaining the blocks together, with a final classification head:

model = nn.Sequential(
    MakePatches(patch_width=patch_width),
    EmbedPatches(token_size=patch_n_pixels, embedding_size=embedding_size),
    PositionalEncoding(embedding_size=embedding_size),
    SimpleAttention(embedding_size=embedding_size),
    nn.Flatten(),
    nn.Linear(n_channels * n_patches * embedding_size, n_classes)
)
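To try your own formulation, one option is to write a module with the same contract as SimpleAttention (input and output shaped [B C patch_idx embedding_size], attention weights stored on `self.attention` with a heads axis) and swap it in, e.g. `model[3] = MyAttention(embedding_size=embedding_size)`. The class below is a hypothetical sketch; its cosine-similarity scoring with a learnable temperature is only a placeholder for whatever formulation you want to test.

```python
import torch
from torch import nn
import torch.nn.functional as F

class MyAttention(nn.Module):
    """Hypothetical drop-in replacement for SimpleAttention (single head).

    In/out: [B C patch_idx embedding_size]. Stores weights as [B C heads q_idx k_idx]
    so the per-class visualisation code can read them through a forward hook.
    """

    def __init__(self, embedding_size, model_size=8):
        super().__init__()
        self.to_q = nn.Linear(embedding_size, model_size, bias=False)
        self.to_k = nn.Linear(embedding_size, model_size, bias=False)
        self.to_v = nn.Linear(embedding_size, model_size, bias=False)
        self.out = nn.Linear(model_size, embedding_size)
        self.log_temperature = nn.Parameter(torch.zeros(()))  # learnable score scale

    def forward(self, x):
        # Placeholder scoring rule: cosine similarity between queries and keys.
        q = F.normalize(self.to_q(x), dim=-1)
        k = F.normalize(self.to_k(x), dim=-1)
        v = self.to_v(x)
        scores = (q @ k.transpose(-2, -1)) * self.log_temperature.exp()
        attention = torch.softmax(scores, dim=-1)  # B C q_idx k_idx
        self.attention = attention.unsqueeze(2)    # add a heads axis for the hook code
        return x + self.out(attention @ v)         # residual connection, as in SimpleAttention
```

Keeping the interface identical means the training loop, hook, and plots can be reused unchanged, so any difference in accuracy or attention maps comes from the attention rule itself.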

After training, it reaches about 92% accuracy on the validation set. You can train for fewer epochs if you want to iterate more quickly. I haven't done much tuning for this net.
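Another way to iterate quickly, which also helps with errors that only appear after a long first epoch, is to train on a small random subset while debugging. This sketch uses `TensorDataset` and `Subset` from `torch.utils.data`; the random tensors are placeholders for `X_trn`/`y_trn`, and `n_debug = 256` is an arbitrary choice.

```python
import torch
from torch.utils.data import DataLoader, Subset, TensorDataset

# Placeholder data with MNIST-like shapes; substitute X_trn, y_trn in practice.
X = torch.randn(1_000, 1, 28, 28)
y = torch.randint(0, 10, (1_000,))
dataset = TensorDataset(X, y)

# A small random subset makes an "epoch" take seconds rather than minutes.
n_debug = 256
indices = torch.randperm(len(dataset))[:n_debug].tolist()
debug_loader = DataLoader(Subset(dataset, indices), batch_size=32, shuffle=True)

n_batches = sum(1 for _ in debug_loader)  # 256 / 32 = 8 minibatches per epoch
```

If the code survives a full pass (including the end-of-epoch metrics) on the subset, it is then worth switching back to the full dataset.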

Visualise some classification results:

Finally, we can hook into the attention layer and observe how the net attends on a class-by-class basis:

For each class, the net seems to home in on the negative space around a digit and, based on its assessment of that negative space, decide whether the digit belongs to the class. The point here is that you can access the attention matrix, which will be useful when analysing your own implementation.
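The mechanism for getting at the attention matrix is a forward hook, which PyTorch calls as `hook(module, inputs, output)` after each forward pass. Here is a minimal self-contained sketch; `TinyAttention` is a hypothetical layer that, like SimpleAttention, saves its weights on `self.attention`. Calling `.detach()` in the hook keeps it working even if the forward pass runs with gradients enabled.

```python
import torch
from torch import nn

class TinyAttention(nn.Module):
    """Hypothetical minimal layer that stores its attention weights on self.attention."""

    def __init__(self, d):
        super().__init__()
        self.qk = nn.Linear(d, 2 * d, bias=False)

    def forward(self, x):
        q, k = self.qk(x).chunk(2, dim=-1)
        self.attention = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return x + self.attention @ x

layer = TinyAttention(d=5)
captured = []

# Copy the stored weights out after every forward pass.
handle = layer.register_forward_hook(
    lambda module, inputs, output: captured.append(module.attention.detach().clone())
)

with torch.no_grad():
    layer(torch.randn(3, 7, 5))  # 3 sequences of 7 tokens each
handle.remove()  # stop capturing once you have what you need

attn = captured[0]  # (3, 7, 7): one query-by-key matrix per sample
```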

The learnt positional encoding matrix:

It has a somewhat binary characteristic, emphasised in the image by clipping the values.

Full data and code are below.

Imports, load & visualise data.

import torch
from torch import nn
from torch.utils.data import DataLoader

from sklearn.datasets import fetch_openml
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt
import numpy as np

np.random.seed(0)

X_orig, y_orig = fetch_openml('mnist_784', return_X_y=True, as_frame=False)

X = X_orig[:60_000]
y = y_orig[:len(X)].astype(int)

X_trn, X_val, y_trn, y_val = train_test_split(
    X, y, stratify=y, shuffle=True, test_size=0.1, random_state=0
)

#Fit the scaler on training data only, then apply the same statistics to both splits
scaler = StandardScaler().fit(X_trn)
X_trn = scaler.transform(X_trn)
X_val = scaler.transform(X_val)

#Reshape
img_width = int(X_trn.shape[1] ** 0.5)
n_channels = 1
n_classes = np.unique(y).size

X_trn = X_trn.reshape(-1, n_channels, img_width, img_width)
X_val = X_val.reshape(-1, n_channels, img_width, img_width)

#View some data
f, axs = plt.subplots(ncols=9, nrows=2, figsize=(8, 2), layout='tight')
axs = axs.flatten()

for i, ax in enumerate(axs):
    ax.imshow(X_trn[i, 0], cmap='binary', vmin=-1, vmax=1)
    ax.axis('off')
    ax.set_title(f'y={y_trn[i]}', fontsize=8)
f.suptitle('Training data and labels', fontsize=10)

#To tensors
X_trn, X_val = [torch.tensor(X).float() for X in [X_trn, X_val]]
y_trn, y_val = [torch.tensor(y).long() for y in [y_trn, y_val]]

patch_width = 2

#Layer that breaks any given sample image into a sequence of patches
class MakePatches(nn.Module):
    def __init__(self, patch_width, final_flatten=True):
        super().__init__()
        self.patch_width = patch_width
        self.final_flatten = final_flatten
        
    def forward(self, x):
        B, C, img_h, img_w = x.shape
        
        patch_h = patch_w = self.patch_width        
        n_patch_cols = n_patch_rows = img_w // patch_w
        
        patches = (
            x
            .reshape(B, C, n_patch_rows, patch_h, n_patch_cols, patch_w)
            .swapdims(3, 4)  #B C patch_row patch_col row_px col_px
        )
        
        if self.final_flatten:
            #2D patch indices flattened to 1D
            patches = patches.reshape(B, C, n_patch_rows * n_patch_cols, patch_h * patch_w)
            #> B C patch_idx patch_elements

        return patches
    
#Visualise on one sample
patches = MakePatches(patch_width=patch_width, final_flatten=False)(X_trn[2:3, 0:1]).squeeze()

f, axs = plt.subplots(nrows=patches.shape[0], ncols=patches.shape[1], figsize=(5, 5))
for row in range(axs.shape[0]):
    for col in range(axs.shape[1]):
        ax = axs[row, col]
        ax.imshow(patches[row, col], cmap='binary', vmin=-2, vmax=2)
        ax.set_xticks([])
        ax.set_yticks([])
f.suptitle('A sample made into patches \n'
           f'patch_width={patch_width}px (thus token_size={patch_width**2})',
           fontsize=10)
f.set_size_inches(3.5, 3.5)

#Layer that embeds patches into a user-defined size
# Could skip this layer and just feed in the raw pixel values
embedding_size = 5

class EmbedPatches(nn.Module):
    def __init__(self, token_size, embedding_size):
        super().__init__()
        
        self.embedding_matrix = nn.Parameter(
            torch.empty(token_size, embedding_size)
        )
        nn.init.kaiming_normal_(self.embedding_matrix)
    
    def forward(self, x):
        return x @ self.embedding_matrix
#Test:
EmbedPatches(8, embedding_size)(torch.ones(2, 1, 12, 8)).shape

#Layer that adds positional encoding (learnable)
class PositionalEncoding(nn.Module):
    def __init__(self, embedding_size, max_sequence_len=512):
        super().__init__()
        
        self.pos_encoding_matrix = nn.Parameter(
            torch.empty(1, 1, max_sequence_len, embedding_size)
        )
        nn.init.kaiming_normal_(self.pos_encoding_matrix)
    
    def forward(self, x):
        return x + self.pos_encoding_matrix[:, :, :x.shape[2], :]

n_patches = (img_width // patch_width) ** 2
patch_n_pixels = patch_width ** 2

#Test:
PositionalEncoding(embedding_size=embedding_size)(torch.ones(1, 1, n_patches, embedding_size));

#Simple attention block
# In:  [B C patch_idx embedding_size]
# Out: [B C patch_idx embedding_size]
class SimpleAttention(nn.Module):
    def __init__(self, embedding_size, n_heads=1, model_size=8):
        super().__init__()
        
        self.model_size = model_size
        self.n_heads = n_heads
        
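        #One fresh tensor per matrix: nn.Parameter shares memory with the tensor it wraps,
        # so wrapping a single tensor three times would tie Q, K and V together
        self.Q, self.K, self.V = [
            nn.Parameter(torch.empty(1, 1, n_heads, embedding_size, model_size))
            for _ in range(3)
        ]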
        
        for matrix in [self.Q, self.K, self.V]:
            nn.init.kaiming_uniform_(matrix)
        
        self.collapse_heads = nn.Linear(n_heads * model_size, embedding_size)
        
    def forward(self, x):
        B, C, n_patches, _ = x.shape
        x = x.unsqueeze(2) #add head dimension
        
        queries, keys, values = [x @ matrix for matrix in [self.Q, self.K, self.V]]
        #> B C heads patch_idx model_size
    
        attention = nn.Softmax(dim=-1)( queries @ keys.swapdims(3, 4) / self.model_size ** 0.5)
        #> B C heads q_idx k_idx
    
        attn_weighted_values = values.swapdims(3, 4) @ attention.swapdims(3, 4)
        #> B C heads model_size patch_idx

        #Combine heads        
        concat_heads = attn_weighted_values.reshape(B, C, self.n_heads * self.model_size, n_patches)        
        merged_heads = self.collapse_heads(concat_heads.swapdims(2, 3))
        #> B C patch_idx embedding_size
        
        #store as instance variable so we can hook in
        self.attention = attention
        
        enriched_x = x.squeeze(dim=2) + merged_heads
        return enriched_x

#Test:
SimpleAttention(embedding_size=embedding_size, n_heads=3)(torch.ones(2, 1, n_patches, embedding_size));
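A pitfall worth checking when you define Q, K and V by hand in your own attention layer: `nn.Parameter(t)` wraps `t` without copying it, so wrapping one tensor several times yields separate Parameter objects that share one block of memory. Initialising one then overwrites the others, and the optimizer keeps updating the same storage. The shape below is an arbitrary example.

```python
import torch
from torch import nn

shape = (1, 1, 1, 5, 8)

# Pitfall: two Parameters wrapping the same tensor share storage.
empty = torch.empty(shape)
Q_shared, K_shared = nn.Parameter(empty), nn.Parameter(empty)
nn.init.kaiming_uniform_(Q_shared)
nn.init.kaiming_uniform_(K_shared)  # also overwrites Q_shared's values

# Allocating a fresh tensor per matrix keeps them independent.
Q, K, V = (nn.Parameter(torch.empty(shape)) for _ in range(3))
for matrix in (Q, K, V):
    nn.init.kaiming_uniform_(matrix)
```

Comparing `data_ptr()` values is a quick way to confirm whether two parameters share memory.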

Define and train model

#Define model
np.random.seed(0)
torch.manual_seed(0)

model = nn.Sequential(
    MakePatches(patch_width=patch_width),
    EmbedPatches(token_size=patch_n_pixels, embedding_size=embedding_size),
    PositionalEncoding(embedding_size=embedding_size),
    SimpleAttention(embedding_size=embedding_size),
    nn.Flatten(), #B C patch_idx embedding_size --> B <all>
    nn.Linear(n_channels * n_patches * embedding_size, n_classes)
)
print('Model size is', sum([p.numel() for p in model.parameters()]) / 1e3, 'K params')

optimizer = torch.optim.Adam(model.parameters())
loss_fn = nn.functional.cross_entropy

#Batchify data using DataLoader
batch_size = 32
loader_common_params = dict(batch_size=batch_size, shuffle=True, num_workers=7)
trn_loader = DataLoader(list(zip(X_trn, y_trn)), **loader_common_params)
val_loader = DataLoader(list(zip(X_val, y_val)), **loader_common_params)

@torch.no_grad()
def compute_accuracy(model, loader, max_samples=np.inf):
    model.eval()
    
    cum_correct = 0
    cum_samples = 0
    for batch_X, batch_y in loader:
        logits = model(batch_X)
        
        cum_correct += (logits.argmax(dim=1) == batch_y).sum()
        cum_samples += len(batch_X)
        
        if cum_samples >= max_samples:
            break
    return cum_correct / cum_samples * 100

#Train
n_epochs = 10

from collections import defaultdict
metrics = defaultdict(list)

for epoch in range(n_epochs):
    model.train()
    epoch_cum_loss = 0
    
    for minibatch_num, (batch_X, batch_y) in enumerate(trn_loader):
        logits = model(batch_X)
        loss = loss_fn(logits, batch_y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_cum_loss += loss.item()

        if (minibatch_num + 1) % 100 == 0:
            print(
                f'[epoch {epoch + 1:>3d}/{n_epochs:>3d}]'
                f'[minibatch {minibatch_num + 1:>4d}/{len(trn_loader):>4d}] '
                f'minibatch loss {loss.item():>5.3f}',
                end='\r'
            )
    
    #Record metrics per epoch
    metrics['trn_loss'].append(epoch_cum_loss / len(trn_loader))
    metrics['trn_acc'].append( compute_accuracy(model, trn_loader, max_samples=2_000) )
    metrics['val_acc'].append( compute_accuracy(model, val_loader) )
    
    print(
        f'[epoch {epoch + 1:>3d}/{n_epochs:>3d}]'
        f'[minibatch {minibatch_num + 1:>4d}/{len(trn_loader):>4d}] '
        f'avg loss {metrics["trn_loss"][-1]:>5.3f}',
        f'| trn acc {metrics["trn_acc"][-1]:>6.3f}%   '
        f'val acc {metrics["val_acc"][-1]:>6.3f}%'
    )

#Plot loss and accuracy
f, ax = plt.subplots(figsize=(5, 2))
ax.plot(metrics['trn_loss'], label='loss')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')

ax2 = ax.twinx()
ax2.plot(metrics['trn_acc'], ls='--', label='trn accuracy')
ax2.plot(metrics['val_acc'], ls='--', label='val accuracy')
ax2.set_ylabel('accuracy (%)')

f.legend(ncol=3)

Visualise results and attention.

model.eval()

#Predictions for some validation samples

f, axs = plt.subplots(nrows=2, ncols=9, figsize=(8, 2), layout='tight')
axs = axs.flatten()

with torch.no_grad(): yhat = model(X_val[:len(axs)]).argmax(dim=1)

for i, ax in enumerate(axs):
    ax.imshow(X_val[i, 0], cmap='binary', vmin=-1, vmax=1)
    ax.axis('off')
    ax.set_title(
        f'y={y_val[i].item()} |' + r'$\hat{y}$=' + str(yhat[i].item()),
        color='tab:green' if y_val[i]==yhat[i] else 'tab:red',
        fontsize=8, fontweight='bold'
    )

#
# Visualisation of attention for each class
#

#Hook into attention layer & store the attention matrix
attentions = []
hook = model[3].register_forward_hook(
    lambda self, input, output: attentions.append(self.attention.numpy())
)

X_subset = X_trn[:10_000]
y_subset = y_trn[:len(X_subset)]

with torch.no_grad(): model(X_subset)
hook.remove()

attentions = np.array(attentions[0])

avg_attention_perclass = np.empty((n_classes, n_patches))
for clas in range(n_classes):
    #B C heads patch_q patch_k
    avg_attention_perclass[clas] = (
        attentions[y_subset==clas] #attention per sample
        .mean(axis=2, keepdims=True) #average over heads       
        .mean(axis=0, keepdims=True) #average over samples
        .mean(axis=3) #average over q = the average k
        .squeeze()
    )
    
f, axs = plt.subplots(ncols=5, nrows=2, figsize=(8, 3), layout='tight')
axs = axs.flatten()
for i, ax in enumerate(axs):
    ax.imshow(avg_attention_perclass[i].reshape([int(n_patches**0.5)] * 2),
              cmap='plasma')
    ax.axis('off')
    ax.set_title(f'class {i}', fontsize=9)
plt.show()

#View encoding matrix
plt.imshow(model[2].pos_encoding_matrix[0, 0, :n_patches, :].detach(), aspect='auto', origin='lower',
           cmap='plasma', vmin=-0.1, vmax=0.1, interpolation='none')
plt.colorbar(ticks=[-0.1, 0, 0.1])
plt.xlabel('embedding dim')
plt.ylabel('patch/token index')
plt.gcf().set_size_inches(2, 4)
plt.title('Learnt positional encoding matrix', fontsize=11);