I am trying to use torch.jit.script to script the demucs audio separation model. Ultimately, I want to make the model compatible with Core ML so I can use it in an iOS app.
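Roughly, the pipeline looks like this (a minimal sketch: the get_model call and the coremltools conversion are stand-ins for my actual code, not part of the traceback, and the model name and input shape are placeholders):

import torch
import coremltools as ct
from demucs.pretrained import get_model

model = get_model(name="mdx_extra_q")  # placeholder: substitute the pretrained demucs model I actually use
model.eval()

scripted = torch.jit.script(model)  # <- this call raises the RuntimeError below

# Eventual goal: feed the scripted model to coremltools
mlmodel = ct.convert(
    scripted,
    inputs=[ct.TensorType(shape=(1, 2, 44100))],  # placeholder shape: (batch, channels, samples)
)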
The torch.jit.script call fails with:

RuntimeError:
Module 'LocalState' has no attribute 'query_freqs' :
The relevant LocalState code from demucs (trimmed):

import math

import torch
from torch import nn


class LocalState(nn.Module):
    """Local state allows to have attention based only on data (no positional embedding),
    but while setting a constraint on the time window (e.g. decaying penalty term).

    Also a failed experiments with trying to provide some frequency based attention.
    """
    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
        super().__init__()
        assert channels % heads == 0, (channels, heads)
        self.heads = heads
        self.nfreqs = nfreqs
        self.query = nn.Conv1d(channels, channels, 1)
        self.key = nn.Conv1d(channels, channels, 1)
        # query_freqs is only registered when nfreqs > 0
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
        # (other layers from the original source omitted for brevity)

    def forward(self, x):
        B, C, T = x.shape
        heads = self.heads
        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
        # left index are keys, right index are queries
        delta = indexes[:, None] - indexes[None, :]
        queries = self.query(x).view(B, heads, -1, T)
        keys = self.key(x).view(B, heads, -1, T)
        # t are keys, s are queries
        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
        dots /= keys.shape[2]**0.5
        if self.nfreqs:
            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
            # Here is where the error occurs:
            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs ** 0.5
            # (rest of forward omitted)
I tried:

- creating the Conv1d layer inside forward instead of __init__, but that wouldn't work for training.
- adding @torch.jit.script above __init__.
- moving the creation of self.query_freqs outside of the conditional statement (see the sketch after this list).
- using torch.jit.Attribute.
- using @torch.jit.export.
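For concreteness, here is the kind of change I mean by "moving the creation outside of the conditional", sketched on a stripped-down stand-in module rather than the full LocalState (FreqAttention and everything in it are names I made up for this sketch; the None fallback plus an explicit None check in forward is one variant of what I tried):

import torch
from torch import nn


class FreqAttention(nn.Module):
    # Stripped-down stand-in for LocalState: only the conditionally
    # created submodule and its use in forward are kept.
    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0):
        super().__init__()
        self.heads = heads
        self.nfreqs = nfreqs
        self.query = nn.Conv1d(channels, channels, 1)
        # Always present on the module; replaced by a real layer when nfreqs > 0.
        self.query_freqs = None
        if nfreqs:
            self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, T = x.shape
        out = self.query(x)
        # Explicit None check instead of `if self.nfreqs:`, so the scripted
        # module either sees a real submodule or can drop this branch.
        if self.query_freqs is not None:
            freq_q = self.query_freqs(x).view(B, self.heads, -1, T) / self.nfreqs ** 0.5
            out = torch.cat([out, freq_q.reshape(B, -1, T)], dim=1)
        return out


# What I'd expect to script cleanly, with or without the optional layer:
torch.jit.script(FreqAttention(64, nfreqs=2))
torch.jit.script(FreqAttention(64, nfreqs=0))

Even with changes along these lines applied to LocalState itself, scripting the full demucs model still fails for me. What is the right way to handle a submodule that only exists for some constructor arguments, so that torch.jit.script (and ultimately the Core ML conversion) goes through?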