Why does testing raise an "invalid size" error when I use the same images and the same network as in training?


I trained a GAN with a criss-cross attention block, but at test time it raises an "invalid size" error even though I use the same images and the same network as in training:

File "ToDayGAN-master/models/networks.py", line 110, in forward
    - (torch.eye(W, W).unsqueeze(0).repeat(B * W, 1, 1) * sys.maxsize)
 RuntimeError: shape '[1, 72, 127, 127]' is invalid for input of size 373248
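
Counting elements for the shapes in the error message (B = 1, W = 72, H = 127 would reproduce these numbers; the exact values at that point are my assumption, I did not print them):

# Assumed values read off the traceback: B = 1, W = 72, H = 127.
B, W, H = 1, 72, 127
print(B * W * W * W)  # 373248  - elements produced by torch.eye(W, W).repeat(B * W, 1, 1)
print(B * W * H * H)  # 1161288 - elements required by .view(B, W, H, H)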

The original criss-cross attention code block is:

import sys

import torch
import torch.nn as nn
import torch.nn.functional as F

# Conv1x1 is a 1x1 convolution helper; it is not shown here and is assumed to
# be defined elsewhere in networks.py.


class Criss_Cross_Attention(nn.Module):
    def __init__(self, in_c):
        super(Criss_Cross_Attention, self).__init__()
        self.in_c = in_c

        self.query = Conv1x1(self.in_c, self.in_c // 8)
        self.key = Conv1x1(self.in_c, self.in_c // 8)
        self.value = Conv1x1(self.in_c, self.in_c)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Dims 2 and 3 of the NCHW input are unpacked as W and H here.
        B, C, W, H = x.size()
        query = self.query(x)
        key = self.key(x)
        value = self.value(x)
        # Reshape query/key/value so attention can be computed along one
        # spatial axis at a time (the *_h tensors) and along the other
        # (the *_w tensors).
        query_h = query.permute(0, 3, 2, 1).contiguous().view(B * W, H, C // 8)
        query_w = query.permute(0, 2, 3, 1).contiguous().view(B * H, W, C // 8)
        key_h = key.permute(0, 3, 1, 2).contiguous().view(B * W, C // 8, H)
        key_w = key.permute(0, 2, 1, 3).contiguous().view(B * H, C // 8, W)
        value_h = value.permute(0, 3, 1, 2).contiguous().view(B * W, C, H)
        value_w = value.permute(0, 2, 1, 3).contiguous().view(B * H, C, W)

        # Axis-wise affinities; the eye-based mask subtracts a huge value on
        # the diagonal so each position's self-affinity vanishes after the
        # softmax. The subtraction below is the line the traceback points at.
        energy_h = (
            torch.bmm(query_h, key_h).contiguous().view(B, W, H, H)
            - (torch.eye(W, W).unsqueeze(0).repeat(B * W, 1, 1) * sys.maxsize)
            .contiguous()
            .view(B, W, H, H)
            .cuda()
        ).permute(0, 2, 1, 3)
        energy_w = torch.bmm(query_w, key_w).contiguous().view(B, H, W, W)

        # Softmax over the concatenated energies, then split back into the
        # two attention maps.
        attention_map = F.softmax(torch.cat([energy_h, energy_w], 3), 3)
        attention_map_h = (
            attention_map[:, :, :, 0:H]
            .permute(0, 2, 1, 3)
            .contiguous()
            .view(B * W, H, H)
            .permute(0, 2, 1)
        )
        attention_map_w = (
            attention_map[:, :, :, H : H + W]
            .contiguous()
            .view(B * H, W, W)
            .permute(0, 2, 1)
        )
        # Aggregate the values with the two attention maps, then add the
        # gamma-weighted result back onto the input.
        out_h = (
            torch.bmm(value_h, attention_map_h)
            .contiguous()
            .view(B, W, -1, H)
            .permute(0, 2, 3, 1)
        )
        out_w = (
            torch.bmm(value_w, attention_map_w)
            .contiguous()
            .view(B, H, -1, W)
            .permute(0, 2, 1, 3)
        )
        out = out_h + out_w
        out = out.view(B, C, W, H)
        out = self.gamma * out + x

        return out
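
For context, this is roughly how I drive the block (a minimal sketch; the batch size, channel count, and 64×64 feature-map size are assumptions rather than the exact shapes inside ToDayGAN, and a GPU is needed because of the .cuda() call inside forward):

# Minimal usage sketch with assumed shapes; with a square feature map the
# forward pass above runs without error.
x = torch.randn(1, 64, 64, 64).cuda()
att = Criss_Cross_Attention(in_c=64).cuda()
out = att(x)
print(out.shape)  # torch.Size([1, 64, 64, 64])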

I tried lower-resolution images (256×256) and it worked, but when I changed the image size to 900×1600, it raised the error above.
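
The numbers in the traceback look consistent with the feature map being non-square at that point. A minimal sketch of what I think changes between the two cases, using only the masking line and no model weights (the 72×127 shape is an assumption taken from the error message, not a shape I printed):

import torch

# Square feature map (as with my 256x256 training images): W == H, so the
# eye-based mask can be viewed as (B, W, H, H).
B, W, H = 1, 64, 64
mask = torch.eye(W, W).unsqueeze(0).repeat(B * W, 1, 1)
print(mask.view(B, W, H, H).shape)  # torch.Size([1, 64, 64, 64])

# Non-square feature map (assumed shape matching the traceback): W != H, so
# the B*W*W*W elements cannot be viewed as (B, W, H, H).
B, W, H = 1, 72, 127
mask = torch.eye(W, W).unsqueeze(0).repeat(B * W, 1, 1)
mask.view(B, W, H, H)  # RuntimeError: shape '[1, 72, 127, 127]' is invalid
                       # for input of size 373248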
