Isinstance slice in Numba jitclass __getitem__


I am using a Numba jitclass and would like to transform the key in __getitem__ whenever it is not a slice, while keeping the normal slicing behavior.

Question: How can I do this?

To give a little context, I would rather write tensor[coord] than tensor[tensor_to_formalseries(coord, tensor.dim)], and I also prefer the condensed tensor[:key] to tensor.formal_series[:key].
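For reference, here is a minimal sketch of the behavior the jitclass is meant to reproduce, written as an ordinary Python class. The class name PyTensor is hypothetical, and tensor_to_formalseries is re-declared as plain Python with the same placeholder body (coordinate * 2) so the snippet runs without Numba.

```python
import numpy as np


def tensor_to_formalseries(coordinate: int, dim: int = 2) -> int:
    # Same placeholder transformation as in the question
    return coordinate * 2


class PyTensor:
    """Hypothetical pure-Python counterpart of Tensor1 (no Numba involved)."""

    def __init__(self, dim: int = 2):
        self.dim = dim
        self.formal_series = np.arange(10)

    def __getitem__(self, key):
        # In plain Python, isinstance works on slice objects directly
        if isinstance(key, slice):
            return self.formal_series[key]
        # Otherwise treat key as a coordinate and map it to a series index
        return self.formal_series[tensor_to_formalseries(key, self.dim)]


tensor = PyTensor()
print(tensor[:3])  # -> [0 1 2]
print(tensor[2])   # -> 4 (index 2 is mapped to 4)
```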

Below are three examples that work in pure Python but fail when compiled as jitclasses.

import numpy as np
from numba import njit
from numba.experimental import jitclass
from numba.core.types import int64, SliceType

@njit
def tensor_to_formalseries(coordinate: int, dim=2):
    key = coordinate * 2 # some crazy stuff with coordinate
    return key
@jitclass(spec={"formal_series": int64[:], "dim": int64})
class Tensor1:
    def __init__(self, dim=2):
        self.dim = dim
        self.formal_series = np.arange(10)

    def __getitem__(self, key):
        if isinstance(key, (slice, SliceType)):
            # tried with SliceLiteral, slice2_type, slice3_type from numba.core.types
            print("key is a slice")
            return self.formal_series[key]
        else:
            print("key is not a slice")
            return self.formal_series[tensor_to_formalseries(key, self.dim)]

tensor = Tensor1()
print(tensor[:3])
print(tensor[2])
"""
Rejected as the implementation raised a specific error:
  NumbaTypeError: isinstance() does not support variables of type "slice<a:b>".
"""
@jitclass(spec={"formal_series": int64[:], "dim": int64})
class Tensor2:
    def __init__(self, dim=2):
        self.dim = dim
        self.formal_series = np.arange(10)

    def __getitem__(self, key):
        if isinstance(self.formal_series[key], np.ndarray):
            print("key is a slice")
            return self.formal_series[key]
        else:
            print("key is not a slice")
            return self.formal_series[tensor_to_formalseries(key, self.dim)]

tensor = Tensor2()
print(tensor[:3])
print(tensor[2])
"""
Use of unsupported NumPy function 'numpy.ndarray' or unsupported use of the function.
"""
@jitclass(spec={"formal_series": int64[:], "dim": int64})
class Tensor3:
    def __init__(self, dim=2):
        self.dim = dim
        self.formal_series = np.arange(10)

    def __getitem__(self, key):
        try:
            len(self.formal_series[key])
            print("key is a slice")
            return self.formal_series[key]
        except Exception:
            # should be a TypeError, but jitclass doesn't like it: UnsupportedError: Exception matching is limited to <class 'Exception'>
            print("key is not a slice")
            return self.formal_series[tensor_to_formalseries(key, self.dim)]

tensor = Tensor3()
print(tensor[:3])
"""
Overload of function 'mul': File: <numerous>: Line N/A.
  With argument(s): '(slice<a:b>, int64)':
"""

print(tensor[2])
"""
Overload of function 'len': File: <numerous>: Line N/A.
  With argument(s): '(int64)':
 No match.
"""

Answer by Rafnus

As of Numba 0.56, isinstance() is not supported inside a jitclass (source: the Numba jitclass documentation).

Do you really need to compile the "is this a slice?" check with Numba? Most likely your code spends most of its time in the transformation, so that is where you should focus your optimization efforts.

What I would do is detect the type of the key with regular Python code, then use an njit-compiled Numba function to do the actual transformation.

I would also avoid the jitclass, which is still experimental. You are probably better off using a regular Python class that simply calls a Numba-compiled function.

Answer by Shep Bryan

Evidently you can pass slices into Numba functions and use them as you would a regular slice; you just can't ask whether an object is a slice. A pretty hacky way around this is to first apply the object to an array as an index, then inspect the result. For your application, I assume you are only checking whether the input is a slice or an int. In that case you can use this code:

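import numba as nb
import numpy as np

@nb.jit(nopython=True)
def _checkifslice(obj, maxsize=100):
    # Index a dummy array with obj: a slice gives an array, an int gives a scalar
    array = np.zeros(maxsize)
    test = array[obj]
    # A scalar has shape (), so zero dimensions means obj was not a slice
    if len(np.shape(test)) == 0:
        return False
    else:
        return True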

If you are using this within a class, you can typically pick a suitable maxsize based on the attributes of the class.

Clearly this isn't an optimal solution, but it is a workaround until Numba makes slices easier to work with.