I am working on a vision transformer based model architecture where I get Q and K tensors of shape (1, 4, 2097152, 32) (we are working with 3D images). When I try to compute the product of Q and K, I get an out-of-memory error.
For the multiplication I tried both torch.matmul and torch.bmm. Neither of them solved the error.
OutOfMemoryError: CUDA out of memory. Tried to allocate 65536.00 GiB. GPU 0 has a total capacity of 15.77 GiB of which 13.47 GiB is free. Process 56962 has 2.30 GiB memory in use. Of the allocated memory 2.00 GiB is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation. See documentation for Memory Management (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
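The 65536 GiB in the message matches the size of the full attention score matrix of shape (1, 4, 2097152, 2097152). A quick check of the arithmetic, assuming float32 (4 bytes per element), which is not stated explicitly above:

```python
# Size of the attention score matrix for Q and K of shape (1, 4, 2097152, 32).
batch, heads, n_tokens, dim_head = 1, 4, 2_097_152, 32
bytes_per_element = 4  # assumption: float32

# Q @ K^T has shape (batch, heads, n_tokens, n_tokens); dim_head is contracted away.
score_elements = batch * heads * n_tokens * n_tokens
score_gib = score_elements * bytes_per_element / 2**30

print(f"{score_gib:.2f} GiB")  # 65536.00 GiB
```

So the allocation does not depend on whether torch.matmul or torch.bmm is used; the result tensor itself is what does not fit.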
This is the code I'm working on. The module is built to perform cross attention.
import torch
from torch import nn
from einops import rearrange

class CPA3D(nn.Module):
    def __init__(self, dim, dim_, heads=2, dim_head=64, dropout=0., p_one=True):
        super().__init__()
        self.dim = dim
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_q = nn.Linear(dim_, heads * dim_head, bias=False)
        self.to_k = nn.Linear(dim, heads * dim_head, bias=False)
        self.to_v = nn.Linear(dim, heads * dim_head, bias=False)
        self.p_one = p_one
        self.to_out = nn.Sequential(
            nn.Linear(heads * dim_head, dim),
            nn.Dropout(dropout)
        )

    def forward(self, x1, x2):
        B, D, H, W, C = x1.shape
        x1_flat = x1.view(B, -1, C)  # B, D*H*W, C
        x2_flat = x2.view(B, -1, C)  # B, D*H*W, C
        # Keys and values come from x1, queries from x2 (cross attention).
        k = self.to_k(x1_flat)
        v = self.to_v(x1_flat)
        q = self.to_q(x2_flat)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.heads)
        # torch.bmm only accepts 3-D tensors, so torch.matmul is shown for the 4-D (b, h, n, d) inputs.
        dots = torch.matmul(q, k.permute(0, 1, 3, 2)) * self.scale  # the OOM error is raised on this line
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        if self.p_one:
            out = self.to_out(out)
        out = out.view(B, D, H, W, -1)
        return out
For reference, I am working on Colab Pro with an A100 GPU.
Can anyone suggest a solution to resolve this?
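One possible direction, given here only as a minimal sketch and not as a verified fix for this model, is to avoid materializing the full n x n score matrix by processing the queries in chunks. Because the softmax runs over the key axis, each block of query rows can be handled independently and the result is the same as full attention. The function name `chunked_attention` and the `chunk_size` parameter are hypothetical, introduced for illustration.

```python
import torch

def chunked_attention(q, k, v, scale, chunk_size=64):
    """Softmax attention computed one block of queries at a time.

    q, k, v have shape (b, h, n, d). Only a (b, h, chunk_size, n) slice
    of the score matrix exists at any moment, never the full (b, h, n, n).
    """
    outputs = []
    for start in range(0, q.shape[2], chunk_size):
        q_chunk = q[:, :, start:start + chunk_size]                  # (b, h, c, d)
        dots = torch.matmul(q_chunk, k.transpose(-1, -2)) * scale    # (b, h, c, n)
        attn = dots.softmax(dim=-1)                                  # softmax over all keys
        outputs.append(torch.matmul(attn, v))                        # (b, h, c, d)
    return torch.cat(outputs, dim=2)

# Small CPU check against the unchunked computation (sizes are illustrative).
torch.manual_seed(0)
q = torch.randn(1, 4, 256, 32)
k = torch.randn(1, 4, 256, 32)
v = torch.randn(1, 4, 256, 32)
scale = 32 ** -0.5

with torch.no_grad():
    out_chunked = chunked_attention(q, k, v, scale, chunk_size=64)
    full = torch.matmul((torch.matmul(q, k.transpose(-1, -2)) * scale).softmax(dim=-1), v)

print(out_chunked.shape, torch.allclose(out_chunked, full, atol=1e-5))
```

For n = 2097152 with 4 heads in float32, a chunk of 64 queries still needs about 2 GiB for each score slice, so the chunk size has to be chosen with that in mind. The sketch is aimed at inference under torch.no_grad(); during training, autograd keeps the per-chunk attention weights for the backward pass, so the total would again approach the full matrix. PyTorch 2.x also offers torch.nn.functional.scaled_dot_product_attention, which can dispatch to fused memory-efficient kernels; whether that fits this sequence length on a given GPU is not verified here.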