Audio Source Separation U-Net NN with complex numbers incompatibility with pytorch

39 Views Asked by At

I have a task at school to do a NN that does source separation on some audio files.I also have to apply STFT to it and use magnitude as training data Did the dataset, 400 .wav files at 48kHz, 10 sec each. And I keep having issues with pytorch's compatibility to complex numbers, it doesnt work at all.

Currently I am getting

RuntimeError: "max_pool2d" not implemented for 'ComplexFloat'

This is the code

class ComplexConv2d(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0):
        super(ComplexConv2d, self).__init__()
        self.conv_real = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)
        self.conv_imag = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding)

    def forward(self, x):
        real = self.conv_real(x.real) - self.conv_imag(x.imag)
        imag = self.conv_real(x.imag) + self.conv_imag(x.real)
        return torch.complex(real, imag)



class ComplexReLU(nn.Module):
    def forward(self, x):
        real_part = F.relu(x.real)
        imag_part = F.relu(x.imag)
        return torch.complex(real_part, imag_part)


class AudioUNet(nn.Module):
    def __init__(self, input_channels, start_neurons):
        super(AudioUNet, self).__init__()

        self.encoder = nn.Sequential(
            ComplexConv2d(input_channels, start_neurons, kernel_size=3, padding=1),
            ComplexReLU(),
            ComplexConv2d(start_neurons, start_neurons, kernel_size=3, padding=1),
            ComplexReLU(),
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Dropout2d(0.25),
            ComplexConv2d(start_neurons, start_neurons * 2, kernel_size=3, padding=1),
            ComplexReLU(),
            ComplexConv2d(start_neurons * 2, start_neurons * 2, kernel_size=3, padding=1),
            ComplexReLU(),
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Dropout2d(0.5),
            ComplexConv2d(start_neurons * 2, start_neurons * 4, kernel_size=3, padding=1),
            ComplexReLU(),
            ComplexConv2d(start_neurons * 4, start_neurons * 4, kernel_size=3, padding=1),
            ComplexReLU(),
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Dropout2d(0.5),
            ComplexConv2d(start_neurons * 4, start_neurons * 8, kernel_size=3, padding=1),
            ComplexReLU(),
            ComplexConv2d(start_neurons * 8, start_neurons * 8, kernel_size=3, padding=1),
            ComplexReLU(),
            nn.MaxPool2d(2, 2, ceil_mode=True),
            nn.Dropout2d(0.5),
            ComplexConv2d(start_neurons * 8, start_neurons * 16, kernel_size=3, padding=1),
            ComplexReLU(),
            ComplexConv2d(start_neurons * 16, start_neurons * 16, kernel_size=3, padding=1)
        )

        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(start_neurons * 16, start_neurons * 8, kernel_size=3, stride=2, padding=1,
                               output_padding=1),
            ComplexConv2d(start_neurons * 16, start_neurons * 8, kernel_size=3, padding=1),
            ComplexReLU(),
            nn.Dropout2d(0.5),
            nn.ConvTranspose2d(start_neurons * 8, start_neurons * 4, kernel_size=3, stride=2, padding=1,
                               output_padding=1),
            ComplexConv2d(start_neurons * 8, start_neurons * 4, kernel_size=3, padding=1),
            ComplexReLU(),
            nn.Dropout2d(0.5),
            nn.ConvTranspose2d(start_neurons * 4, start_neurons * 2, kernel_size=3, stride=2, padding=1,
                               output_padding=1),
            ComplexConv2d(start_neurons * 4, start_neurons * 2, kernel_size=3, padding=1),
            ComplexReLU(),
            nn.Dropout2d(0.5),
            nn.ConvTranspose2d(start_neurons * 2, start_neurons, kernel_size=3, stride=2, padding=1, output_padding=1),
            ComplexConv2d(start_neurons * 2, start_neurons, kernel_size=3, padding=1),
            ComplexReLU(),
            nn.Dropout2d(0.5),
            ComplexConv2d(start_neurons, 1, kernel_size=1)
        )

    def forward(self, x):
        x = x.unsqueeze(1)  # Assuming the channel dimension is the first dimension

        # Process through the encoder
        encoder_output = self.encoder(x)

        # Process through the decoder
        decoder_output = self.decoder(encoder_output)

        # Combine the encoder and decoder outputs
        output = encoder_output + decoder_output

        # Assuming you want to return the real part of the output
        return output.squeeze(1)

Now, I have the NN model,did a ComplexConv function as long as a ComplexRelu, but I keep getting error because I am using complex numbers and I am just circling around in errors, i tried with chatgpt but it resolves one error and then there is another one. Can you please tell me if I am on the right path and maybe how could I fix the complex number incompatibility problem?

0

There are 0 best solutions below