SymPy: replacing Derivative with a custom function results in an error

39 Views Asked by At

Consider the following setup:

from sympy import Function
eta = Function('eta')
from sympy import *
# Three symbols declared real; t and u appear in the expression below, s is the bound variable.
s,t,u = symbols("s t u",real = True)

# Test expression: a product of two Subs(Derivative(eta(s), s), ...) factors evaluated at t + I*u and t - I*u.
error_=pi*exp(t)*exp(-I*u)*Subs(Derivative(eta(s), s), s, t + I*u)*Subs(Derivative(eta(s), s), s, t - I*u)/(2*eta(t - I*u)*eta(t + I*u))

I defined a function to simplify the operation:

def Manual_Derivative(*args):
    """Callback handed to ``expr.replace(Derivative, ...)``.

    Expects exactly two positional arguments: the differentiated
    expression and a variable specification (printed in the post as
    ``(s, 1)``, i.e. variable and order). The received arguments are
    echoed to stdout for debugging.

    Returns ``eta(order)``, where ``order`` is 1 when the specification
    has a single entry and its second entry otherwise. The differentiated
    expression itself is not used.

    Raises ``ValueError`` (from the unpacking) when called with anything
    other than two arguments.
    """
    print(args)
    _expr, var_spec = args
    # A one-element specification carries no explicit order, so default to 1.
    order = 1 if len(var_spec) == 1 else var_spec[1]
    # NOTE(review): for the call shown in the post, args is (eta(s), (s, 1)),
    # so this yields eta(1), which no longer depends on s. Presumably this is
    # related to the TypeError raised when the enclosing Subs is rebuilt;
    # confirm against SymPy's Subs implementation.
    return eta(order)

where

# Rewrite every Derivative node via the callback; the output below shows it being called with (eta(s), (s, 1)).
error_.replace(Derivative, Manual_Derivative)

printed the callback's arguments and then raised an error:

(eta(s), (s, 1))
(eta(s), (s, 1))

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[103], line 1
----> 1 error_.replace(Derivative, Manual_Derivative)

File ~\anaconda3\lib\site-packages\sympy\core\basic.py:1666, in Basic.replace(self, query, value, map, simultaneous, exact)
   1663             expr = v
   1664     return expr
-> 1666 rv = walk(self, rec_replace)
   1667 return (rv, mapping) if map else rv

File ~\anaconda3\lib\site-packages\sympy\core\basic.py:1643, in Basic.replace.<locals>.walk(rv, F)
   1641 newargs = tuple([walk(a, F) for a in args])
   1642 if args != newargs:
-> 1643     rv = rv.func(*newargs)
   1644     if simultaneous:
   1645         # if rv is something that was already
   1646         # matched (that was changed) then skip
   1647         # applying F again
   1648         for i, e in enumerate(args):

File ~\anaconda3\lib\site-packages\sympy\core\cache.py:72, in __cacheit.<locals>.func_wrapper.<locals>.wrapper(*args, **kwargs)
     69 @wraps(func)
     70 def wrapper(*args, **kwargs):
     71     try:
---> 72         retval = cfunc(*args, **kwargs)
     73     except TypeError as e:
     74         if not e.args or not e.args[0].startswith('unhashable type:'):

File ~\anaconda3\lib\site-packages\sympy\core\operations.py:98, in AssocOp.__new__(cls, evaluate, _sympify, *args)
     95 if len(args) == 1:
     96     return args[0]
---> 98 c_part, nc_part, order_symbols = cls.flatten(args)
     99 is_commutative = not nc_part
    100 obj = cls._from_args(c_part + nc_part, is_commutative)

File ~\anaconda3\lib\site-packages\sympy\core\mul.py:703, in Mul.flatten(cls, seq)
    700 c_part = _new
    702 # order commutative part canonically
--> 703 _mulsort(c_part)
    705 # current code expects coeff to be always in slot-0
    706 if coeff is not S.One:

File ~\anaconda3\lib\site-packages\sympy\core\mul.py:35, in _mulsort(args)
     33 def _mulsort(args):
     34     # in-place sorting of args
---> 35     args.sort(key=_args_sortkey)

File ~\anaconda3\lib\site-packages\sympy\core\basic.py:281, in Basic.compare(self, other)
    279     c = l.compare(r)
    280 else:
--> 281     c = (l > r) - (l < r)
    282 if c:
    283     return c

TypeError: unsupported operand type(s) for -: 'StrictGreaterThan' and 'StrictLessThan'

However, if we simply use

# Same replacement, but with an undefined SymPy Function as the value instead of the Python callback above.
Manual_derivative = Function("Manual_derivative")
error_.replace(Derivative, Manual_derivative )

the error would not occur

Is there a way to fix this, so that the replacement with Manual_Derivative can be carried out?


Updates:

The

class Manual_Derivative_cl(Function):
   # Variant that puts the rewrite logic in an ``eval`` classmethod; the body is elided in the post.
   @classmethod
   def eval(cls, *args):
      ...

also didn't work. However, when I used

class Manual_Derivative_cl(Function):
    # Variant that only overrides ``_eval_evalf``; the body is elided in the post.
    def _eval_evalf(self, prec):
        ...

and then called .evalf() on the result, it somewhat worked. There might be a bug somewhere in the .replace() function.

0

There are 0 best solutions below