What is wrong with my mask, i.e. why does it not work? I am trying to predict the mask token based only on the first token and the masked token itself. For this I created a special create_causal_mask method that builds this mask for multi-head attention. When I run it, a tensor of NaN values is returned. The pad mask and the mask of masked tokens do not intersect, and I fill my attention mask as described in the torch documentation for torch.nn.TransformerDecoder. Also, the attention output should not be empty, because there are some False values in the attention mask. So why doesn't it work?
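To make the goal concrete, here is a minimal standalone sketch (the names allowed and intended_attn_mask are only for this illustration) of the per-sequence pattern I expect create_causal_mask to produce, using PyTorch's boolean attn_mask convention in which True marks positions that are not allowed to attend:

import torch

# Tiny illustration, separate from the model code below: one sequence of
# length 4 whose token at index 2 is the mask token. The masked token should
# attend only to position 0 and to itself; True = blocked, False = allowed.
allowed = torch.zeros(4, 4, dtype=torch.bool)
allowed[0, 0] = True   # the first token attends to itself
allowed[2, 0] = True   # the masked token attends to the first token
allowed[2, 2] = True   # ... and to itself
intended_attn_mask = ~allowed  # invert so that True means "blocked"
print(intended_attn_mask)

The full model and the script that reproduces the problem follow.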
import torch
from torch import nn
torch.manual_seed(0)
class LookOnFirstDecoder(nn.Module):
    def __init__(self, depth, d_model, nhead, d_ff,
                 dropout, activation,
                 sent_length, n_tokens, pad_idx
                 ):
        """
        :param sent_length: max length of a sentence
        :param n_tokens: number of tokens in the vocabulary, including the mask and padding tokens
        :param pad_idx: index of the padding token, so that no gradient is computed for its embedding
        """
        super().__init__()
        self.d_model = d_model
        self.nhead = nhead
        self.n_tokens = n_tokens
        self.emb = nn.Embedding(
            num_embeddings=n_tokens,
            embedding_dim=d_model,
            padding_idx=pad_idx
        )
        self.pos_embed = nn.Parameter(
            torch.zeros(1, sent_length, d_model),
            requires_grad=True
        )
        torch.nn.init.normal_(self.pos_embed, std=.02)
        self.transformer = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(
                d_model=d_model,
                nhead=nhead,
                dim_feedforward=d_ff,
                dropout=dropout,
                activation=activation,
                batch_first=True,
                norm_first=True,
            ),
            num_layers=depth,
        )
        self.fin_lin = nn.Linear(d_model, n_tokens)
    def create_causal_mask(self, mask):
        """
        The purpose is to create a mask that lets every non-first token
        look only at the first token and at itself.
        :param mask: (B, L)
        :return: (B * nhead, L, L)
        """
        mask[:, 0] = True  # so that tokens depend on the first token
        b, l = mask.shape
        batch_causal_mask = ~torch.tril(mask.unsqueeze(-1) * mask.unsqueeze(-2))  # (B, L, L)
        # batch_causal_mask = torch.tril(torch.ones((b, l, l))).to("cuda") == 0
        # batch_causal_mask = torch.where(batch_causal_mask, 0, float('-inf'))
        print(f"Batch causal mask: \n{batch_causal_mask}")
        causal_mask = (
            batch_causal_mask.
            unsqueeze(1).                    # (B, 1, L, L)
            expand(b, self.nhead, l, l).     # (B, nhead, L, L)
            reshape(b * self.nhead, l, l)    # (B * nhead, L, L)
        )
        return causal_mask
    def forward(self, tgt, memory, is_masked_mask, is_pad_mask):
        """
        :param tgt: (B, L)
        :param memory: (B, L1, D)
        :param is_masked_mask: (B, L), True - mask token, False - not
        :param is_pad_mask: (B, L), True - pad token, False - not
        :return: tensor of shape (B, n_tokens)
        """
        b, l = tgt.shape
        tgt_tokens = self.emb(tgt) + self.pos_embed[:, :l].expand(b, l, self.d_model)
        tgt_tokens = self.transformer(
            tgt_tokens,
            memory,
            tgt_mask=self.create_causal_mask(is_masked_mask.clone()),
            tgt_is_causal=True,
            tgt_key_padding_mask=is_pad_mask
        )  # (B, L, D)
        fin_tokens = self.fin_lin(tgt_tokens[is_masked_mask])
        return fin_tokens
# my vocabulary
n_tokens = 10 # pad_idx - 9, mask_idx - 8
pad_idx = n_tokens - 1
mask_idx = n_tokens - 2
d_model = 4
nhead = 2
b, l = 3, 8
model = LookOnFirstDecoder(
    depth=2,
    d_model=d_model,
    nhead=nhead,
    d_ff=8,
    dropout=0.1,
    activation="gelu",
    sent_length=l,
    n_tokens=n_tokens,
    pad_idx=pad_idx
)
memory = torch.randn(b, l, d_model)
# create some random tokens, without padding or mask tokens
in_tokens = torch.randint(0, mask_idx - 1, (b, l))
# add mask and padding tokens manually
in_tokens[0, 6:] = pad_idx
in_tokens[0, 5] = mask_idx
in_tokens[1, 7:] = pad_idx
in_tokens[1, 4] = mask_idx
in_tokens[2, 5:] = pad_idx
in_tokens[2, 0] = mask_idx
is_masked_mask = in_tokens == mask_idx
is_pad_mask = in_tokens == pad_idx
pred = model(in_tokens, memory, is_masked_mask, is_pad_mask)
print(f"In tokens: \n{in_tokens}")
print(f"Pad mask: \n{is_pad_mask}")
print(f"Masked mask: \n{is_masked_mask}")
print(f"Pred: \n{pred}")
This is my requirements.txt:
torch==2.1.1
torchvision==0.16.1
xformers
albumentations==1.3.1
numpy==1.26.2
scipy==1.11.4
scikit-learn==1.3.2
pandas==2.1.4
matplotlib==3.8.2
seaborn==0.13.0
This is the result after execution:
Batch causal mask:
tensor([[[False,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [False,  True,  True,  True,  True, False,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True]],

        [[False,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [False,  True,  True,  True, False,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True]],

        [[False,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True]]])
In tokens:
tensor([[2, 1, 4, 0, 3, 8, 9, 9],
        [6, 4, 0, 6, 8, 0, 5, 9],
        [8, 2, 5, 2, 6, 9, 9, 9]])
Pad mask:
tensor([[False, False, False, False, False, False,  True,  True],
        [False, False, False, False, False, False, False,  True],
        [False, False, False, False, False,  True,  True,  True]])
Masked mask:
tensor([[False, False, False, False, False,  True, False, False],
        [False, False, False, False,  True, False, False, False],
        [ True, False, False, False, False, False, False, False]])
Pred:
tensor([[nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan],
        [nan, nan, nan, nan, nan, nan, nan, nan, nan, nan]],
       grad_fn=<AddmmBackward0>)