Getting the adjoint state of a solution in Gekko

After solving an optimal control problem in Gekko (IMODE = 6), is there any way to access or reconstruct the adjoint state p? Since the documentation does not provide any resource for this, I am hoping there is a way to retrieve some information that may lead to a reconstruction of the adjoint state.

A bonus question: is there any optimal control solver (with a Python API) that returns the adjoint state?

Best answer, by John Hedengren

Gekko has the option m.options.SENSITIVITY = 1, which writes sensitivity.txt to the run directory m.path. However, this only works when there are zero degrees of freedom (DOF): a simulation, or an optimization problem that has zero DOF when the MV STATUS is off.

The other alternative is to add the adjoint equations to the problem. Here is an example dynamic optimization problem.

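import numpy as np
from gekko import GEKKO
m = GEKKO()
nt = 101
m.time = np.linspace(0,1,nt)     # time horizon [0, 1] with 101 points
x = m.Var(value=1)               # state, x(0) = 1
u = m.Var(value=0,lb=-1,ub=1)    # control, bounded to [-1, 1]
m.Equation(x.dt()==-x + u)       # dynamics dx/dt = -x + u
m.Minimize(x**2)                 # objective term x^2
m.options.IMODE = 6              # dynamic optimization (optimal control) mode
m.solve()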

Adding the adjoint equation gives the value of lam, which (up to the sign convention used for the Hamiltonian) is the sensitivity of the objective function to changes in x. The transversality condition for this problem is lam(T)=0, a final-time constraint, while the initial condition lam(0) is left free and calculated by the solver. This is achieved by declaring lam with fixed_initial=False and calling m.fix_final(lam,0) to fix the final value at zero.
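The adjoint equation used in the code can be checked by hand. The short derivation below is written for the continuous-time cost $J = \int_0^T x^2\,dt$ (an assumption: Gekko evaluates `x**2` at the discrete time points rather than this exact integral) and uses the maximum-principle sign convention, which is the one that matches `lam.dt() == 2*x + lam`:

$$
\begin{aligned}
H &= -x^2 + \lambda\,(-x + u), \\
\dot{\lambda} &= -\frac{\partial H}{\partial x} = 2x + \lambda, \qquad \lambda(T) = 0 .
\end{aligned}
$$

With the minimum-principle convention $\tilde{H} = x^2 + p\,(-x + u)$, the costate is $p = -\lambda$ and satisfies $\dot{p} = -\partial\tilde{H}/\partial x = -2x + p$, $p(T) = 0$. Here $p(t)$ is the sensitivity of the remaining cost $\int_t^T x^2\,ds$ to a change in $x(t)$, so the lam computed by the model is minus that sensitivity.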

[Figure: costate results, plotting x, u and λ versus time]

import numpy as np
from gekko import GEKKO
import matplotlib.pyplot as plt

m = GEKKO()
nt = 101
m.time = np.linspace(0, 1, nt)
x = m.Var(value=1)
u = m.Var(value=0, lb=-1, ub=1)
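lam = m.Var(value=0,fixed_initial=False) # Adjoint variable; lam(0) is left free
m.fix_final(lam,0)                       # transversality condition lam(T) = 0
m.Equation(x.dt() == -x + u)
m.Equation(lam.dt() == 2*x + lam)  # Adjoint state equation: lam' = -dH/dx with H = -x**2 + lam*(-x + u)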
m.Minimize(x**2)
m.options.IMODE = 6
m.solve(disp=False)

# Plotting
plt.figure(figsize=(7,4))
plt.plot(m.time, x.value, 'b-', lw=2, label=r'$x$')
plt.plot(m.time, u.value, 'r--', lw=2, label=r'$u$')
plt.plot(m.time, lam.value, 'g-.', lw=2, label=r'$\lambda$')
plt.xlabel('Time'); plt.legend()
plt.grid(); plt.tight_layout()
plt.show()
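As an alternative to adding lam to the model, the adjoint can also be reconstructed after the solve by integrating the same costate equation backward in time from lam(T)=0 over the returned x trajectory. The sketch below is only an illustration under stated assumptions: the adjoint equation lam' = 2x + lam of this example is hard-coded, `reconstruct_adjoint` and `cost` are hypothetical helper names, and it uses NumPy with a trapezoidal rule instead of Gekko's own discretization. The demo uses a fixed control u = 0 (so x = e^(-t)) so the result can be checked against the analytic adjoint and against a finite-difference derivative of J with respect to x(0), which in this sign convention should be close to -lam(0).

```python
import numpy as np


def reconstruct_adjoint(t, x):
    """Hypothetical helper: integrate lam' = 2*x + lam backward in time.

    Hard-coded for the example problem (dx/dt = -x + u, cost x^2), with the
    same sign convention as the Gekko model. Starts from the transversality
    condition lam(T) = 0 and applies the trapezoidal rule on the (possibly
    non-uniform) grid t.
    """
    t = np.asarray(t, dtype=float)
    x = np.asarray(x, dtype=float)
    lam = np.zeros_like(t)  # lam[-1] = 0 is the final-time condition
    for k in range(len(t) - 2, -1, -1):
        h = t[k + 1] - t[k]
        # Trapezoid: lam[k+1] - lam[k] = h/2 * (2x[k] + lam[k] + 2x[k+1] + lam[k+1]),
        # rearranged to solve for lam[k]
        lam[k] = (lam[k + 1] * (1 - h / 2) - h * (x[k] + x[k + 1])) / (1 + h / 2)
    return lam


def cost(t, x):
    """Hypothetical helper: trapezoidal approximation of J = integral of x^2 dt."""
    t = np.asarray(t, dtype=float)
    y = np.asarray(x, dtype=float) ** 2
    return float(np.sum(0.5 * (y[1:] + y[:-1]) * np.diff(t)))


if __name__ == "__main__":
    t = np.linspace(0, 1, 101)
    x = np.exp(-t)  # state trajectory for fixed control u = 0 and x(0) = 1
    lam = reconstruct_adjoint(t, x)

    # Analytic adjoint for this trajectory: lam(t) = -(exp(-t) - exp(t - 2))
    exact = -(np.exp(-t) - np.exp(t - 2))
    print("max error vs analytic:", np.max(np.abs(lam - exact)))

    # Central finite difference of J with respect to x(0); compare with -lam(0)
    eps = 1e-4
    dJ_dx0 = (cost(t, (1 + eps) * np.exp(-t)) - cost(t, (1 - eps) * np.exp(-t))) / (2 * eps)
    print("dJ/dx0 =", dJ_dx0, "  -lam(0) =", -lam[0])
```

With the Gekko script above, the helper could be called as `reconstruct_adjoint(m.time, x.value)` and compared with `lam.value`; small differences are expected because Gekko's collocation scheme is not the trapezoidal rule used here. The analytic value for the demo is lam(0) = -(1 - e^(-2)) ≈ -0.8647.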