For my thesis I have to run some large calculations in python, which I'm trying to speed up using numba. I have the following code:
import numpy as np
from numba import njit
# %%
#@njit
def full_model(bdotyr,T,Tav,nYear,dt):
rhoi = 917.00 # ice density
Eg = 4.24E4 # Activation eneryg for grain growth, from A10
kg = 1.3E-7 # grain growth constant, from A10
R = 8.314 # ideal gas constant
g = 9.81665 # gravity constant
rho0 = 315. # fresh polar snow density
rsq0 = (0.001*0.3)**2 # initial grain size of 0.3 mm
# derived constants
nval = int(nYear*86400*365.25/dt)
A = np.array([9.268e-9, 8.869e-14])
Ea = np.array([42.4e3, 49e3])
m = np.array([2, 1.4])
n = np.array([1, 1.8])
class Max_state_type:
@njit
def __init__(self, nTime, rho0, rsq0):
self.nTime = nTime
self.iTime = 0
self.rho = np.zeros(nTime,np.float64)
self.rho[0] = rho0
self.rsq = np.zeros(nTime,np.float64)
self.rsq[0] = rsq0
self.s = np.zeros(nTime,np.float64)
self.sigma = np.zeros(nTime,np.float64)
self.z = np.zeros(nTime,np.float64)
@njit
def dr2dt(T):
return kg*np.exp(-Eg/(R*T))
@njit
def f12(A, Ea, T, s, m, n, r):
return A*np.exp(-Ea/(R*T))*r**(-m)*s**n
@njit
def theta(rho):
return (1+np.exp(-0.07*(rho - 550)))**(-1)
@njit
def drhodt(A, Ea, T, s, m, n, rsq, rho):
f1 = f12(A[0], Ea[0], T, s, m[0], n[0], np.sqrt(rsq))
f2 = f12(A[1], Ea[1], T, s, m[1], n[1], np.sqrt(rsq))
return rho*(1-theta(rho))* f1 + rho*theta(rho) * f2
@njit
def get_s(rho, rhoi, sigma):
return (rhoi/rho - 1)*rhoi/rho * sigma
@njit
def eval_model(state, A, Ea):
# using RK4,
bdot = bdotyr/(86400*365.25)
state.s[0] = get_s(state.rho[0], rhoi,state.sigma[0])
for i in range(1, nval):
#First sigma
state.sigma[i] = state.sigma[i-1] + g*bdot*dt
state.s[i] = get_s(state.rho[i-1], rhoi, state.sigma[i])
# r^2:
q1 = dr2dt(T)
state.rsq[i] = state.rsq[i-1] + dt*q1
# now rho:
state.rho[i] = state.rho[i-1]+ dt* drhodt(A,Ea,T,state.s[i],m,n,state.rsq[i],state.rho[i-1])
state.z[i] = state.z[i-1] + bdot/state.rho[i]*dt
state.rsq = np.sqrt(state.rsq)
return state
state = Max_state_type(nval, rho0, rsq0)
state = eval_model(state, A, Ea)
return state
#some constants
bdotyr = np.arange(100,1000,100)
Temp = np.arange(273.15-30,273.15,30/9)
rhoCrit = 550
Tav = 273.15-28
nYear = 100
dt = 8640
@njit
def get_zCrit(state,rhoCrit):
for (i, item) in enumerate(state.rho):
if np.abs(item-rhoCrit)<np.abs(state.rho[i+1]-rhoCrit):
return state.z[i]
@njit
def get_plotData(bdotyr,Temp,Tav,nYear,dt,rhoCrit):
zCrit = np.zeros((len(bdotyr),len(Temp)),float64)
for i in range(len(bdotyr)):
for j in range(len(Temp)):
state = full_model(bdotyr[i],Temp[j],Tav,nYear,dt)
zCrit[i,j] = get_zCrit(state,rhoCrit)
return bdotyr, Temp, zCrit
X_Max, Y_Max, Z_Max = get_plotData(bdotyr,Temp,Tav,nYear,dt,rhoCrit)
This code doesn't work, I get the error message
TypingError: non-precise type pyobject
when I try to call the function get_plotData(). I don't really see why this happens, since all input variables have a specified numba type. I already tried specifying a signature for the function get_plotData(), this didn't change anything unfortunately. Moreover I specified the type of local variable zCrit, this also didn't work.What am I missing?