I am using a numba jitclass and would like to apply a transformation to the key in `__getitem__` whenever it is not a slice, while keeping the normal slicing behaviour.
Question: how can I do this?
To give a little context, I would rather write `tensor[coord]` than `tensor[tensor_to_formalseries(coord, tensor.dim)]`, and I also prefer the condensed `tensor[:key]` to `tensor.formal_series[:key]`.
Below are three examples that work in pure Python but fail as jitclasses.
```python
import numpy as np
from numba import njit
from numba.experimental import jitclass
from numba.core.types import int64, SliceType


@njit
def tensor_to_formalseries(coordinate: int, dim=2):
    key = coordinate * 2  # some crazy stuff with coordinate
    return key


@jitclass(spec={"formal_series": int64[:], "dim": int64})
class Tensor1:
    def __init__(self, dim=2):
        self.dim = dim
        self.formal_series = np.arange(10)

    def __getitem__(self, key):
        if isinstance(key, (slice, SliceType)):
            # tried with SliceLiteral, slice2_type, slice3_type from numba.core.types
            print("key is a slice")
            return self.formal_series[key]
        else:
            print("key is not a slice")
            return self.formal_series[tensor_to_formalseries(key, self.dim)]


tensor = Tensor1()
print(tensor[:3])
print(tensor[2])
"""
Rejected as the implementation raised a specific error:
NumbaTypeError: isinstance() does not support variables of type "slice<a:b>".
"""
```
```python
@jitclass(spec={"formal_series": int64[:], "dim": int64})
class Tensor2:
    def __init__(self, dim=2):
        self.dim = dim
        self.formal_series = np.arange(10)

    def __getitem__(self, key):
        if isinstance(self.formal_series[key], np.ndarray):
            print("key is a slice")
            return self.formal_series[key]
        else:
            print("key is not a slice")
            return self.formal_series[tensor_to_formalseries(key, self.dim)]


tensor = Tensor2()
print(tensor[:3])
print(tensor[2])
"""
Use of unsupported NumPy function 'numpy.ndarray' or unsupported use of the function.
"""
```
```python
@jitclass(spec={"formal_series": int64[:], "dim": int64})
class Tensor3:
    def __init__(self, dim=2):
        self.dim = dim
        self.formal_series = np.arange(10)

    def __getitem__(self, key):
        try:
            len(self.formal_series[key])
            print("key is a slice")
            return self.formal_series[key]
        except Exception:
            # should be a TypeError, but jitclass doesn't like it:
            # UnsupportedError: Exception matching is limited to <class 'Exception'>
            print("key is not a slice")
            return self.formal_series[tensor_to_formalseries(key, self.dim)]


tensor = Tensor3()
print(tensor[:3])
"""
Overload of function 'mul': File: <numerous>: Line N/A.
With argument(s): '(slice<a:b>, int64)':
"""
print(tensor[2])
"""
Overload of function 'len': File: <numerous>: Line N/A.
With argument(s): '(int64)':
No match.
"""
```
As of numba 0.56, `isinstance()` is not supported inside a numba class (source: the numba jitclass documentation).

Is it really necessary to compile the "is this a slice" check with numba? Most likely your code spends most of its time in the transformation, so the transformation is where your optimisation effort should go.
What I would do is detect the type of the key with regular Python code, then use an njit-compiled numba function to do the actual transformation.
I would also avoid the experimental jitclass. You are probably better off using a regular Python class that simply calls a numba-compiled function.