I'm encountering an issue regarding the input shape for PyTorch's MultiheadAttention. I have initialized MultiheadAttention as follows:
attention = MultiheadAttention(embed_dim=1536, num_heads=4)
The input tensors have the following shapes:
- query.shape is torch.Size([1, 1, 1536])
- Both key.shape and value.shape are torch.Size([1, 23, 1536])
However, when attempting to use these inputs, I encounter the following error:
RuntimeError Traceback (most recent call last)
Cell In[15], line 1
----> 1 _ = cal_attn_weight_embedding(attention, top_j_sim_video_embeddings_list)
File ~/main/reproduct/choi/make_embedding.py:384, in cal_attn_weight_embedding(attention, top_j_sim_video_embeddings_list)
381 print(embedding.shape)
383 # attention
--> 384 output, attn_weights = attention(thumbnail, embedding, embedding)
385 # attn_weight shape: (1, 1, j+1)
387 attn_weights = attn_weights.squeeze(0).unsqueeze(-1) # shape: (j+1, 1)
File ~/anaconda3/envs/choi_venv/lib/python3.8/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/choi_venv/lib/python3.8/site-packages/torch/nn/modules/activation.py:1205, in MultiheadAttention.forward(self, query, key, value, key_padding_mask, need_weights, attn_mask, average_attn_weights, is_causal)
1191 attn_output, attn_output_weights = F.multi_head_attention_forward(
1192 query, key, value, self.embed_dim, self.num_heads,
...
5281 # TODO finish disentangling control flow so we don't do in-projections when statics are passed
5282 assert static_k.size(0) == bsz * num_heads, \
5283 f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
RuntimeError: shape '[1, 4, 384]' is invalid for input of size 35328
Why am I encountering this error?
The main execution environment is as follows:
- Ubuntu 20.04
- Anaconda 1.7.2
- Python 3.8.5
- VSCode 1.87.2
- PyTorch 2.0.1
Thanks in advance for any help.
You need to change

attention = MultiheadAttention(embed_dim=1536, num_heads=4)

to

attention = MultiheadAttention(embed_dim=1536, num_heads=4, batch_first=True)

With the default batch_first=False, MultiheadAttention interprets its inputs as (seq_len, batch, embed_dim), so your query is read as batch size 1 while your key/value are read as batch size 23. That batch-size mismatch is what surfaces as the reshape error: the view to [1, 4 heads, 384 head_dim] expects 1 × 4 × 384 = 1536 elements, but the key tensor actually holds 23 × 1536 = 35328. Setting batch_first=True makes the module read your tensors as (batch, seq_len, embed_dim), which matches the shapes you are passing.
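Here is a minimal sketch of the fixed call, using random placeholder tensors in place of your thumbnail/embedding tensors but with the same shapes as in your question:

```python
import torch
from torch.nn import MultiheadAttention

# batch_first=True: inputs are (batch, seq_len, embed_dim)
attention = MultiheadAttention(embed_dim=1536, num_heads=4, batch_first=True)

query = torch.randn(1, 1, 1536)   # 1 thumbnail embedding (placeholder)
key = torch.randn(1, 23, 1536)    # 23 candidate embeddings (placeholder)
value = torch.randn(1, 23, 1536)

output, attn_weights = attention(query, key, value)
print(output.shape)        # torch.Size([1, 1, 1536])
print(attn_weights.shape)  # torch.Size([1, 1, 23])
```

Alternatively, you could keep the default batch_first=False and instead pass sequence-first tensors, i.e. query of shape (1, 1, 1536) and key/value of shape (23, 1, 1536) (for example via key.transpose(0, 1)), but since your data is already laid out batch-first, setting batch_first=True is the simpler change.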