Examples

Example use of the ORMATEX package.

Example 1: Forced Lotka-Volterra

This example demonstrates defining a non-autonomous ODE system and integrating the system with an exponential integration routine in ORMATEX. The system is given by:

\[ \begin{align}\begin{aligned}\frac{dx}{dt} = \alpha x - \beta x y\\\frac{dy}{dt} = \delta x y - \gamma y - h_y(t)\, y\\h_y(t) = 0.4 (\sin(0.2 t) + 1)\\\alpha = \beta = \delta = \gamma = 1\end{aligned}\end{align} \]

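Here \(x\) and \(y\) are the prey and predator populations, respectively, and \(h_y(t)\) is a prescribed, time-dependent per-capita predator hunting rate (in the script it is additionally multiplied by the `-ft_scale` option, default 1). The result is compared against the dopri5 integration routine provided by Diffrax.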

# Example using the Python bindings to the ORMATEX Rust integrators
#
# NOTE: Before executing this example, ensure the ormatex
# package is installed by running the following:
#
#     maturin develop --release
#
import numpy as np
import jax
from jax import numpy as jnp
from ormatex_py.integrate_wrapper import integrate
from ormatex_py.ormatex import PySysWrapped, integrate_wrapper_rs
from ormatex_py.ode_sys import OdeSys, OdeSysNp, MatrixLinOp
jax.config.update("jax_enable_x64", True)


@jax.jit
def f_pred_hunt(t):
    """Prescribed predator hunting rate h_y(t) = 0.4 * (sin(0.2 t) + 1).

    Works on scalars or arrays of times (evaluated elementwise).
    """
    phase = 0.2 * t
    return 0.4 * (1.0 + jnp.sin(phase))

# Define the nonautonomous ODE system
class LotkaVolterra(OdeSys):
    """Forced (nonautonomous) Lotka-Volterra predator-prey system.

    The state is ``x = [prey, predator]``. In addition to the classic
    Lotka-Volterra terms, the predator equation contains a time-dependent
    loss term ``-ft_scale * f_pred_hunt(t) * x[1]``.

    Args:
        a: coefficient ``alpha`` of the linear prey term.
        b: coefficient ``beta`` of the prey-predator interaction term in the
            prey equation.
        d: coefficient ``delta`` of the prey-predator interaction term in the
            predator equation.
        g: coefficient ``gamma`` of the linear predator loss term.
        ft_scale: multiplier applied to the forcing term ``f_pred_hunt(t)``.
        **kwargs: accepted for backward compatibility and otherwise ignored.
    """
    alpha: float     # coefficient of alpha * prey
    beta: float      # coefficient of -beta * prey * predator
    delta: float     # coefficient of delta * prey * predator
    gamma: float     # coefficient of -gamma * predator
    ft_scale: float  # scale of the time-dependent hunting term

    def __init__(self, a=1.0, b=1.0, d=1.0, g=1.0, ft_scale=1.0, **kwargs):
        self.alpha = a
        self.beta = b
        self.delta = d
        self.gamma = g
        # ft_scale used to be read from **kwargs; it is now an explicit
        # keyword so it shows up in the signature. Other kwargs are ignored,
        # as before.
        self.ft_scale = ft_scale
        super().__init__()

    @jax.jit
    def _frhs(self, t, x, **kwargs):
        """Return ``[d(prey)/dt, d(predator)/dt]`` evaluated at ``(t, x)``."""
        # Hunter populations are known functions of time.
        pred_hunt = f_pred_hunt(t) * self.ft_scale
        # Predator-prey time derivatives.
        prey_t = self.alpha * x[0] - self.beta * x[0] * x[1]
        pred_t = (self.delta * x[0] * x[1] - self.gamma * x[1]
                  - pred_hunt * x[1])
        res = jnp.asarray([prey_t, pred_t])
        # NOTE(review): inside jax.jit, jax.device_get passes tracers through
        # unchanged; flatten() guarantees a 1-D result.
        return jax.device_get(res).flatten()

def run_model(dt, nsteps, method="exprb2_rs", tol_fdt=1.0e-6, ft_scale=1.0,
              t0=0.0, y0=None):
    """Integrate the forced Lotka-Volterra system and a dopri5 reference.

    Args:
        dt: time step size.
        nsteps: number of steps to take.
        method: ORMATEX integration method name forwarded to ``integrate``.
        tol_fdt: nonautonomous-system check threshold forwarded to
            ``integrate`` for ``method`` (not used for the reference run).
        ft_scale: multiplier on the predator-hunting forcing term.
        t0: initial time (default ``0.0``).
        y0: initial ``[prey, predator]`` state; defaults to ``[0.1, 0.2]``.

    Returns:
        ``(t, y, t_ref, y_ref)`` as NumPy arrays: the output times and states
        from ``method``, followed by those from the dopri5 reference run.
    """
    y0_np = np.array([0.1, 0.2]) if y0 is None else np.array(y0, dtype=float)

    # Step the system forward with the requested method.
    res = integrate(LotkaVolterra(ft_scale=ft_scale), y0_np, t0, dt, nsteps,
                    method=method, tol_fdt=tol_fdt)

    # Check against dopri5 in diffrax, which is given a JAX array.
    y0_jax = jnp.array(y0_np.flatten())
    res_expected = integrate(LotkaVolterra(ft_scale=ft_scale), y0_jax, t0, dt,
                             nsteps, method="dopri5")

    # Return NumPy arrays for both runs so callers get consistent types.
    return (np.asarray(res.t), np.asarray(res.y),
            np.asarray(res_expected.t), np.asarray(res_expected.y))

if __name__ == "__main__":
    import argparse
    # pyplot is the recommended plotting interface; matplotlib.pylab is discouraged.
    import matplotlib.pyplot as plt

    parser = argparse.ArgumentParser(
        description="Integrate a forced Lotka-Volterra system with ORMATEX "
                    "and compare against a dopri5 reference solution.")
    parser.add_argument("-method", help="Integration method. Valid methods are: "
                        "exprb2_rs, exprb3_rs, epi3_rs, exprb2, exprb3, epi3. "
                        "Methods ending in _rs are rust impl. Others are python/JAX impl.",
                        type=str, default="epi3_rs")
    parser.add_argument("-ft_scale", help="Forcing term scale", type=float, default=1.0)
    parser.add_argument("-dt", help="time step size", type=float, default=0.05)
    parser.add_argument("-nsteps", help="number of steps", type=int, default=1000)
    parser.add_argument("-tol_fdt", help="Nonautonomous system check threshold",
                        type=float, default=1.0e-6)
    args = parser.parse_args()

    t_out, y_out, t_true, y_true = run_model(
        args.dt, args.nsteps, args.method, args.tol_fdt, ft_scale=args.ft_scale)
    # Normalize to NumPy so the arithmetic and plotting below do not depend on
    # whether run_model handed back NumPy or JAX arrays.
    t_out, y_out = np.asarray(t_out), np.asarray(y_out)
    t_true, y_true = np.asarray(t_true), np.asarray(y_true)
    print(y_out)

    # The error metric compares the two runs sample-by-sample, so they must
    # report the same number of output states.
    if y_out.shape != y_true.shape:
        raise ValueError(
            f"Cannot compare solutions: method output has shape {y_out.shape} "
            f"but the dopri5 reference has shape {y_true.shape}.")
    # Mean absolute error over every output time and both species
    # (previously only the prey component was included).
    mae = np.mean(np.abs(y_out - y_true))

    # Visualize results
    plt.figure()
    plt.plot(t_out, y_out[:, 0], label='prey')
    plt.plot(t_out, y_out[:, 1], label='pred')
    plt.plot(t_true, y_true[:, 0], ls='--', label='prey true')
    plt.plot(t_true, y_true[:, 1], ls='--', label='pred true')
    plt.plot(t_out, np.asarray(f_pred_hunt(t_out)) * args.ft_scale,
             label='predator hunters')
    plt.grid(ls='--')
    plt.title(f"Method: {args.method}, MAE err: {mae:0.3e}")
    plt.legend()
    plt.savefig(f"ormatex_rspy_lv_{args.method}.png")
    plt.close()

After installing the ORMATEX package, run the following from the examples directory:

python ex_ormatex_rspy.py -method epi3

This gives the following result:

_images/ormatex_rspy_lv_epi3.png

If the Python bindings to the Rust-based integrators are installed (see the install guide), additional integrator methods (those whose names end in `_rs`) are available. Executing the following runs a Rust-based implementation of an exponential Rosenbrock method:

python ex_ormatex_rspy.py -method exprb2_rs

This gives the following result:

_images/ormatex_rspy_lv_exprb2_rs.png