Find the best parameters for a custom equation

71 Views Asked by At

I'm trying to find the best parameters for a custom equation using `scipy.optimize.curve_fit`, but the results are not good. I suspect I'm doing something wrong and that it would be better to use `scipy.optimize.minimize`. I've also heard that the Nelder-Mead algorithm should work well for this task, but I don't have enough experience to implement it in practice.

My current code:

from scipy.optimize import curve_fit
        
# Sample data: x and z are the two inputs and y is the observed output.
# Element i of each list belongs to the same observation (29 samples each).
xdata = [0.221,0.23,0.24,0.242,0.233,0.21,0.171,0.221,0.231,0.237,0.227,0.213,0.209,0.209,0.196,0.207,0.213,0.218,0.187,0.196,0.203,0.205,0.219,0.224,0.216,0.205,0.2,0.184,0.169]
zdata = [317,316.6,316.2,315.8,315.4,315,314.6,312.6,312.1,311.7,311.3,310.9,310.5,310.1,301.2,300.8,300.4,300,296.7,296.3,295.9,291,290.6,290.2,289.8,289.4,289,288.6,270.8]
ydata = [0.211,0.199,0.192,0.197,0.212,0.246,0.329,0.252,0.238,0.231,0.251,0.282,0.296,0.305,0.231,0.211,0.203,0.195,0.248,0.234,0.224,0.183,0.163,0.156,0.161,0.167,0.162,0.168,0.253]
  
def func(X, a, b, c, d):
    """Evaluate y = L + (1 - L) * (c / z) ** (1 / d), where L = a + b * x.

    X is an ``(x, z)`` pair; x and z may be scalars or NumPy arrays of the
    same shape. When ``c / z`` is negative and ``1 / d`` is fractional the
    power has no real value: plain Python floats give a complex number and
    NumPy arrays give NaN.
    """
    x, z = X
    linear = a + b * x
    return linear + (1 - linear) * (c / z) ** (1.0 / d)

# Least-squares fit of a, b, c, d. The (xdata, zdata) pair is passed as a
# single independent-variable argument and unpacked inside func; maxfev
# raises the cap on the number of model evaluations.
popt, pcov = curve_fit(
    func,
    (xdata, zdata),
    ydata,
    p0=[-0.1, -0.5, 0.01, 5],
    maxfev=1_000_000,
)

x and z are the input variables; both are always positive. y is the output (target) variable, always between 0 and 1 inclusive, which I want to predict from x and z.

a, b, c and d are the parameters to fit, using either `curve_fit` or `minimize` from `scipy.optimize`. The model is $y = (a + bx) + \bigl(1 - (a + bx)\bigr)\,(c/z)^{1/d}$.

1

There is 1 solution below

4
Reinderien On

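Your initial guess isn't very good, and you can't tell that unless you plot it. Provide a better initial guess and you'll see from the plot that the fit requires a negative c. A negative c only behaves well under exponentiation if you pull its sign out of the expression. Otherwise (c/z) is negative, and raising a negative number to the fractional power 1/d has no real result, so NumPy returns NaN.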

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colorbar import Colorbar
from matplotlib.tri import TriContourSet
from scipy.optimize import curve_fit


# The four model parameters (a, b, c, d) taken by `func`, in that order.
Params = tuple[float, float, float, float]


def func(xz: np.ndarray, a: float, b: float, c: float, d: float) -> np.ndarray:
    """Evaluate y = L + (1 - L) * sign(c) * (|c| / z) ** (1 / d), with L = a + b*x.

    The sign of ``c`` is taken out before the fractional power is applied.
    A negative ``c`` therefore gives a real, sign-flipped term instead of
    NaN, and for ``c >= 0`` this equals the plain ``(c / z) ** (1 / d)`` form.

    :param xz: pair ``(x, z)`` of inputs (arrays or scalars of matching shape).
    :param a: intercept of the linear part L.
    :param b: slope of the linear part L with respect to x.
    :param c: scale of the power term; may be negative.
    :param d: reciprocal exponent of the power term.
    :returns: model output with the broadcast shape of x and z.
    """
    x, z = xz
    linear = a + b*x
    root = (np.abs(c)/z) ** (1/d)
    return linear + (1 - linear) * np.sign(c) * root


def fit(
    x: np.ndarray,
    y: np.ndarray,
    z: np.ndarray,
    *,
    p0: Params = (1, -5e-3, -2.5e+5, 1),
) -> tuple[Params, np.ndarray]:
    """Fit ``func`` to the samples with ``curve_fit`` and return ``(p0, popt)``.

    :param x: first input variable, one value per sample.
    :param y: observed output, one value per sample.
    :param z: second input variable, one value per sample.
    :param p0: initial guess for ``(a, b, c, d)``. The default is the
        hand-picked guess (note the negative ``c``) that makes this data set
        converge. The fit is sensitive to the starting point, so callers
        fitting other data can pass their own.
    :returns: the initial guess that was used (so it can be plotted next to
        the fit) and the fitted parameters as the NumPy array returned by
        ``curve_fit``.
    """
    # (x, z) is passed as one independent-variable argument; func unpacks it.
    popt, _pcov = curve_fit(f=func, p0=p0, xdata=(x, z), ydata=y)
    return p0, popt


def plot(xdata: np.ndarray, ydata: np.ndarray, zdata: np.ndarray, p0: Params, popt: Params) -> plt.Figure:
    """Draw measured y next to the model evaluated at the initial guess and at the fit.

    The panels sit in a 2x2 grid: experiment (top left), initial guess
    (top right) and fit (bottom left). Each panel is a filled triangulated
    contour of y over the (x, z) samples. One colour bar, built from the
    experimental contour set, is attached to all three panels.

    Returns the created figure without showing it.
    """
    fig: plt.Figure = plt.figure()

    # The bottom-left panel is created first so the top-left one can share its x axis.
    ax_fit: plt.Axes = fig.add_subplot(2, 2, 3)
    ax_exp: plt.Axes = fig.add_subplot(2, 2, 1, sharex=ax_fit)
    ax_guess: plt.Axes = fig.add_subplot(2, 2, 2, sharey=ax_exp)

    # Pin every panel's colour normalisation to the measured range of y.
    # NOTE(review): vmin/vmax only fix the colour mapping. Contour levels are
    # still chosen per call, so band edges may differ between panels.
    color_range = dict(vmin=ydata.min(), vmax=ydata.max())

    contour_set: TriContourSet = ax_exp.tricontourf(xdata, zdata, ydata, **color_range)
    for model_ax, params in ((ax_guess, p0), (ax_fit, popt)):
        model_ax.tricontourf(xdata, zdata, func((xdata, zdata), *params), **color_range)

    bar: Colorbar = fig.colorbar(contour_set, ax=[ax_exp, ax_guess, ax_fit])
    bar.set_label('y')

    for titled_ax, title in ((ax_exp, 'Experiment'), (ax_guess, 'Initial guess'), (ax_fit, 'Fit')):
        titled_ax.set_title(title)
    ax_exp.set_ylabel('z')
    ax_fit.set_xlabel('x')
    ax_fit.set_ylabel('z')

    return fig


def main() -> None:
    """Fit the model to the sample data, print the fitted parameters and show the plots."""
    # The 29 observations as NumPy arrays: func and plot apply arithmetic
    # and .min()/.max() to them directly, which plain lists would not support.
    xdata = np.array((
        0.221, 0.23, 0.24, 0.242, 0.233, 0.21, 0.171, 0.221, 0.231, 0.237, 0.227, 0.213, 0.209, 0.209, 0.196, 0.207,
        0.213, 0.218, 0.187, 0.196, 0.203, 0.205, 0.219, 0.224, 0.216, 0.205, 0.2, 0.184, 0.169))
    ydata = np.array((
        0.211, 0.199, 0.192, 0.197, 0.212, 0.246, 0.329, 0.252, 0.238, 0.231, 0.251, 0.282, 0.296, 0.305, 0.231, 0.211,
        0.203, 0.195, 0.248, 0.234, 0.224, 0.183, 0.163, 0.156, 0.161, 0.167, 0.162, 0.168, 0.253))
    zdata = np.array((
        317, 316.6, 316.2, 315.8, 315.4, 315, 314.6, 312.6, 312.1, 311.7, 311.3, 310.9, 310.5, 310.1, 301.2, 300.8,
        300.4, 300, 296.7, 296.3, 295.9, 291, 290.6, 290.2, 289.8, 289.4, 289, 288.6, 270.8))

    p0, popt = fit(xdata, ydata, zdata)
    print(popt)  # fitted (a, b, c, d)
    plot(xdata, ydata, zdata, p0, popt)
    plt.show()


# Run the fit-and-plot demo only when executed as a script, not on import.
if __name__ == '__main__':
    main()
Output (fitted a, b, c, d):

[ 9.99844947e-01 -8.76454497e-04 -1.05139069e+05  7.56371094e-01]

fit