Install Guide

Python-Only Install

To obtain the latest development version of ORMATEX, clone the repo with:

git clone https://github.com/ORNL/ORMATEX.git

For a local development install, run:

cd ORMATEX
pip install -e .
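Before running the tests, you can quickly confirm that the editable install is visible to the active Python interpreter by looking the packages up with importlib. This is a minimal sketch. It only checks the ormatex_py package name used in the Quick Start imports, plus jax, and it neither imports nor exercises them. The helper name check_packages is hypothetical.

```python
import importlib.util


def check_packages(names=("ormatex_py", "jax")):
    """Return a mapping of top-level package name -> True if it can be found."""
    # find_spec returns None (rather than raising) for a missing top-level name
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    for name, found in check_packages().items():
        print(f"{name}: {'found' if found else 'MISSING'}")
```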

After installing, run the Python unit tests from the project root directory (the directory containing this README):

pytest

Rust and Python Bindings

Install rustup by following the instructions at https://www.rust-lang.org/tools/install

Then install the stable Rust toolchain:

rustup toolchain install stable

With Rust and Cargo set up, build the optimized (release) version:

cargo build --release

To use the Rust-based ORMATEX integrators from Python, build the Python-Rust bindings with:

pip install maturin
maturin develop --release

Be sure to pass the --release flag. Without it, maturin builds in debug mode, which significantly degrades performance. Note that maturin develop installs the bindings into the currently active virtualenv or conda environment, so activate one first.

Quick Start

Imports

import jax
from jax import numpy as jnp
from ormatex_py.ode_sys import OdeSys, CustomJacLinOp
from ormatex_py import integrate_wrapper

Define the system

class LotkaVolterraAD(OdeSys):
    alpha: float
    beta: float
    delta: float
    gamma: float

    def __init__(self, a=1.0, b=1.0, d=1.0, g=1.0, **kwargs):
        super().__init__()
        self.alpha = a
        self.beta = b
        self.delta = d
        self.gamma = g

    @jax.jit
    def _frhs(self, t, x, **kwargs):
        # x[0] is the prey population, x[1] the predator population
        prey_t = self.alpha * x[0] - self.beta * x[0] * x[1]   # d(prey)/dt
        pred_t = self.delta * x[0] * x[1] - self.gamma * x[1]  # d(predator)/dt
        return jnp.array([prey_t, pred_t])
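The Lotka-Volterra system defined above conserves the quantity V = delta*x - gamma*ln(x) + beta*y - alpha*ln(y) along exact solutions (x = prey, y = predator), which gives an integrator-independent check on computed trajectories. The sketch below uses plain NumPy rather than the ormatex_py classes; lv_rhs and lv_invariant are hypothetical helpers, not part of ORMATEX.

```python
import numpy as np


def lv_rhs(x, alpha=1.0, beta=1.0, delta=1.0, gamma=1.0):
    """Lotka-Volterra right-hand side, same formula as LotkaVolterraAD._frhs."""
    prey, pred = x[0], x[1]
    return np.array([
        alpha * prey - beta * prey * pred,
        delta * prey * pred - gamma * pred,
    ])


def lv_invariant(x, alpha=1.0, beta=1.0, delta=1.0, gamma=1.0):
    """First integral V = delta*x - gamma*ln(x) + beta*y - alpha*ln(y).

    x may be one state of shape (2,) or a trajectory of shape (n, 2)
    (one row per time point); populations must be strictly positive.
    """
    x = np.asarray(x, dtype=float)
    prey, pred = x[..., 0], x[..., 1]
    return delta * prey - gamma * np.log(prey) + beta * pred - alpha * np.log(pred)


if __name__ == "__main__":
    # dV/dt = grad(V) . f should vanish; estimate it with a central difference along f
    x0 = np.array([0.1, 0.2])
    f0 = lv_rhs(x0)
    eps = 1e-6
    dVdt = (lv_invariant(x0 + eps * f0) - lv_invariant(x0 - eps * f0)) / (2 * eps)
    print(f"dV/dt at x0 ~ {dVdt:.2e}")
```

For example, lv_invariant(np.asarray(y_res)) applied to the Quick Start output (assuming res.y stores one state per row) returns one value per time point; the spread of those values shows how much the numerical solution drifted.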

Initialize the system and integrate

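method = 'epi3'             # integration method
sys = LotkaVolterraAD()     # default parameters: a = b = d = g = 1.0
y0 = jnp.array([0.1, 0.2])  # initial [prey, predator] populations
t0 = 0.0                    # start time
dt = 0.2                    # step size
nsteps = 100                # number of steps of size dt
# max_krylov_dim and iom are Krylov subspace settings for the integrator
res = integrate_wrapper.integrate(sys, y0, t0, dt, nsteps, method, max_krylov_dim=4, iom=2)
t_res, y_res = res.t, res.y  # time points and corresponding states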

Optionally, an explicit Jacobian can be supplied by implementing _fjac. If none is supplied, as in the example above, the Jacobian is computed by automatic differentiation.

class LotkaVolterra(OdeSys):
    alpha: float
    beta: float
    delta: float
    gamma: float

    def __init__(self, a=1.0, b=1.0, d=1.0, g=1.0, **kwargs):
        super().__init__()
        self.alpha = a
        self.beta = b
        self.delta = d
        self.gamma = g

    @jax.jit
    def _frhs(self, t, x, **kwargs):
        prey_t = self.alpha * x[0] - self.beta * x[0] * x[1]
        pred_t = self.delta * x[0] * x[1] - self.gamma * x[1]
        return jnp.array([prey_t, pred_t])

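    @jax.jit
    def _fjac(self, t, x, **kwargs):
        # Analytic Jacobian: row i holds the partial derivatives of rhs component i w.r.t. x[0] and x[1]
        jac = jnp.array([
            [self.alpha - self.beta * x[1], -self.beta * x[0]],
            [self.delta * x[1], self.delta * x[0] - self.gamma]
            ])
        # Wrap the dense Jacobian in a CustomJacLinOp linear operator
        return CustomJacLinOp(t, x, self.frhs, jac)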
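A hand-written _fjac is worth checking against a numerical Jacobian before you rely on it. The sketch below compares the analytic Lotka-Volterra Jacobian from the class above with a central finite-difference approximation in plain NumPy; the helper names are hypothetical. In a JAX workflow, jax.jacfwd applied to the right-hand side would give an automatic-differentiation reference instead.

```python
import numpy as np


def lv_rhs(x, a=1.0, b=1.0, d=1.0, g=1.0):
    """Right-hand side, same formula as LotkaVolterra._frhs."""
    return np.array([a * x[0] - b * x[0] * x[1], d * x[0] * x[1] - g * x[1]])


def lv_jac(x, a=1.0, b=1.0, d=1.0, g=1.0):
    """Analytic Jacobian, same entries as the matrix built in LotkaVolterra._fjac."""
    return np.array([[a - b * x[1], -b * x[0]],
                     [d * x[1], d * x[0] - g]])


def fd_jacobian(f, x, eps=1e-6):
    """Central finite-difference Jacobian: column j approximates df/dx_j."""
    x = np.asarray(x, dtype=float)
    cols = []
    for j in range(x.size):
        e = np.zeros_like(x)
        e[j] = eps
        cols.append((f(x + e) - f(x - e)) / (2 * eps))
    return np.stack(cols, axis=1)


if __name__ == "__main__":
    params = dict(a=1.5, b=0.7, d=0.4, g=2.0)
    x = np.array([0.3, 1.2])
    J_exact = lv_jac(x, **params)
    J_fd = fd_jacobian(lambda z: lv_rhs(z, **params), x)
    print("max abs difference:", np.max(np.abs(J_exact - J_fd)))
```

Because this right-hand side is quadratic, the central difference is exact up to rounding, so any visible mismatch points to an error in the hand-written Jacobian.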