torch.jit.script error for nn.Module attributes created conditionally in __init__: "Module has no attribute"


I am trying to script the Demucs audio source separation model with `torch.jit.script`. Ultimately, I want to convert the model to Core ML so I can use it in an iOS app.

Scripting fails with this error:

```
RuntimeError: 
Module 'LocalState' has no attribute 'query_freqs' :
```
Here is the relevant part of `LocalState` (some layers and the rest of `forward` are omitted):

```python
import math

import torch
from torch import nn


class LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).

    Also a failed experiments with trying to provide some frequency based attention.
    """
    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        # ... other layers (self.query, self.key, ...) omitted ...

        # query_freqs only exists when nfreqs is non-zero
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]

        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2]**0.5
        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            # This is the line the error points to:
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
            # ... rest of forward omitted ...
```
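The same failure can apparently be reproduced without Demucs. This is a minimal sketch, assuming a recent PyTorch release; `Repro` and `try_script` are hypothetical names. TorchScript compiles every branch of `forward`, including the `if self.nfreqs > 0:` branch that never runs when `nfreqs == 0`, so a submodule that was never assigned becomes a compile-time error.

```python
import torch
from torch import nn


class Repro(nn.Module):
    """Hypothetical minimal module mirroring the LocalState pattern."""

    def __init__(self, channels: int = 4, nfreqs: int = 0):
        super().__init__()
        self.nfreqs = nfreqs
        self.query = nn.Conv1d(channels, channels, 1)
        # The submodule is only created when nfreqs is non-zero.
        if nfreqs > 0:
            self.query_freqs = nn.Conv1d(channels, nfreqs, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.query(x)
        # nfreqs is an ordinary int attribute, so TorchScript still compiles
        # this branch and must resolve self.query_freqs.
        if self.nfreqs > 0:
            out = out + self.query_freqs(x).sum(dim=1, keepdim=True)
        return out


def try_script(nfreqs: int) -> str:
    """Return "ok" if scripting succeeds, otherwise the compiler's message."""
    try:
        torch.jit.script(Repro(nfreqs=nfreqs))
    except RuntimeError as err:
        return str(err)
    return "ok"


if __name__ == "__main__":
    print(try_script(2))  # attribute exists, scripting succeeds
    print(try_script(0))  # attribute missing, compiler error
```

With `nfreqs=2` the module scripts; with `nfreqs=0` the error names `query_freqs`, matching the Demucs case.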

I tried:

  1. Creating the Conv1d layers inside forward instead, but that doesn't work for training, since the layers (and their weights) would be re-created on every call.
  2. Adding @torch.jit.script above __init__.
  3. Moving the creation of self.query_freqs outside of the if statement.
  4. Using torch.jit.Attribute.
  5. Using @torch.jit.export.
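One pattern that may avoid the error, sketched here under the assumption of a recent PyTorch release and not tested against the full Demucs model or the Core ML conversion: always define the attribute, set it to `None` when the layer is unused, and guard its use with `is not None`. torchvision's ResNet blocks use the same check for their optional `downsample` layer. `LocalStateSketch` is a hypothetical, cut-down class that keeps only the query path.

```python
import torch
from torch import nn


class LocalStateSketch(nn.Module):
    """Hypothetical cut-down module, not the real Demucs LocalState."""

    def __init__(self, channels: int = 8, heads: int = 4, nfreqs: int = 0):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.query = nn.Conv1d(channels, channels, 1)
        # Always define the attribute: a real layer when it is needed,
        # None otherwise, so TorchScript can always resolve it.
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        else:
            self.query_freqs = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, T = x.shape
        queries = self.query(x).view(B, self.heads, -1, T)
        # When query_freqs is None, TorchScript treats this check as
        # statically false and skips compiling the body.
        if self.query_freqs is not None:
            freq_q = self.query_freqs(x).view(B, self.heads, -1, T) / self.nfreqs ** 0.5
            queries = torch.cat([queries, freq_q], dim=2)
        return queries


if __name__ == "__main__":
    x = torch.randn(2, 8, 5)
    for nf in (0, 2):
        scripted = torch.jit.script(LocalStateSketch(nfreqs=nf))
        print(nf, tuple(scripted(x).shape))
```

Unlike attempt 3, the layer is still only created when `nfreqs` is non-zero, so a model built with `nfreqs=0` keeps the same `state_dict` keys as before.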