i trained a GAN network with cross Att, but there raises "invalid size" while I use the same images and same network in training when testing:
File "ToDayGAN-master/models/networks.py", line 110, in forward
- (torch.eye(W, W).unsqueeze(0).repeat(B * W, 1, 1) * sys.maxsize)
RuntimeError: shape '[1, 72, 127, 127]' is invalid for input of size 373248
the origin code block of cross Att is
class Criss_Cross_Attention(nn.Module):
def __init__(self, in_c):
super(Criss_Cross_Attention, self).__init__()
self.in_c = in_c
self.query = Conv1x1(self.in_c, self.in_c // 8)
self.key = Conv1x1(self.in_c, self.in_c // 8)
self.value = Conv1x1(self.in_c, self.in_c)
self.gamma = nn.Parameter(torch.zeros(1))
def forward(self, x):
B, C, W, H = x.size()
query = self.query(x)
key = self.key(x)
value = self.value(x)
query_h = query.permute(0, 3, 2, 1).contiguous().view(B * W, H, C // 8)
query_w = query.permute(0, 2, 3, 1).contiguous().view(B * H, W, C // 8)
key_h = key.permute(0, 3, 1, 2).contiguous().view(B * W, C // 8, H)
key_w = key.permute(0, 2, 1, 3).contiguous().view(B * H, C // 8, W)
value_h = value.permute(0, 3, 1, 2).contiguous().view(B * W, C, H)
value_w = value.permute(0, 2, 1, 3).contiguous().view(B * H, C, W)
energy_h = (
torch.bmm(query_h, key_h).contiguous().view(B, W, H, H)
- (torch.eye(W, W).unsqueeze(0).repeat(B * W, 1, 1) * sys.maxsize)
.contiguous()
.view(B, W, H, H)
.cuda()
).permute(0, 2, 1, 3)
energy_w = torch.bmm(query_w, key_w).contiguous().view(B, H, W, W)
attention_map = F.softmax(torch.cat([energy_h, energy_w], 3), 3)
attention_map_h = (
attention_map[:, :, :, 0:H]
.permute(0, 2, 1, 3)
.contiguous()
.view(B * W, H, H)
.permute(0, 2, 1)
)
attention_map_w = (
attention_map[:, :, :, H : H + W]
.contiguous()
.view(B * H, W, W)
.permute(0, 2, 1)
)
out_h = (
torch.bmm(value_h, attention_map_h)
.contiguous()
.view(B, W, -1, H)
.permute(0, 2, 3, 1)
)
out_w = (
torch.bmm(value_w, attention_map_w)
.contiguous()
.view(B, H, -1, W)
.permute(0, 2, 1, 3)
)
out = out_h + out_w
out = out.view(B, C, W, H)
out = self.gamma * out + x
return out
I tried some lower resolution images(256256) and it worked. But I changed the size of images(9001600), it raised the error