Install Guide
Python-Only Install
To obtain the latest development version of ORMATEX, clone the repo with:
git clone https://github.com/ORNL/ORMATEX.git
For a local development install, run:
cd ORMATEX
pip install -e .
After running the above, the Python unit tests can be executed. From the project base directory (the directory containing this README), run:
pytest
Rust and Python Bindings
Download and install rustup from https://www.rust-lang.org/tools/install
Then install the stable Rust toolchain:
rustup toolchain install stable
After setting up Rust and Cargo, run the following for an optimized build:
cargo build --release
To use the Rust-based ORMATEX integrators from Python, build the Python-Rust bindings with:
pip install maturin
maturin develop --release
Be sure to pass the --release flag for an optimized build. Without it, maturin builds in debug mode, which significantly degrades performance.
Quick Start
Imports
import jax
from jax import numpy as jnp
from ormatex_py.ode_sys import OdeSys, CustomJacLinOp
from ormatex_py import integrate_wrapper
Define the system
class LotkaVolterraAD(OdeSys):
    alpha: float
    beta: float
    delta: float
    gamma: float

    def __init__(self, a=1.0, b=1.0, d=1.0, g=1.0, **kwargs):
        super().__init__()
        self.alpha = a
        self.beta = b
        self.delta = d
        self.gamma = g

    @jax.jit
    def _frhs(self, t, x, **kwargs):
        prey_t = self.alpha * x[0] - self.beta * x[0] * x[1]
        pred_t = self.delta * x[0] * x[1] - self.gamma * x[1]
        return jnp.array([prey_t, pred_t])
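For reference, the _frhs method above implements the Lotka-Volterra predator-prey equations, with x[0] (written x_0 below) the prey population and x[1] (x_1) the predator population:

```latex
\frac{dx_0}{dt} = \alpha x_0 - \beta x_0 x_1, \qquad
\frac{dx_1}{dt} = \delta x_0 x_1 - \gamma x_1
```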
Initialize the system and integrate
method = 'epi3'
sys = LotkaVolterraAD()
y0 = jnp.array([0.1, 0.2])
t0 = 0.0
dt = 0.2
nsteps = 100
res = integrate_wrapper.integrate(sys, y0, t0, dt, nsteps, method, max_krylov_dim=4, iom=2)
t_res, y_res = res.t, res.y
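To sanity-check integrator output, it can help to have an independent reference solution. The sketch below is not part of ORMATEX: it copies the LotkaVolterraAD right-hand side into a plain function (lotka_volterra_rhs, a hypothetical helper) and integrates it with a classical fixed-step fourth-order Runge-Kutta scheme (RK4, swapped in here purely for comparison) using the same y0, t0, dt and nsteps as above. It also evaluates V = δx_0 - γ ln x_0 + βx_1 - α ln x_1, which is exactly conserved by the Lotka-Volterra flow, so its drift gives a rough accuracy measure. The layout of res.y is not assumed here.

```python
import jax
import jax.numpy as jnp


def lotka_volterra_rhs(t, x, alpha=1.0, beta=1.0, delta=1.0, gamma=1.0):
    """Plain-function copy of LotkaVolterraAD._frhs (hypothetical helper)."""
    prey_t = alpha * x[0] - beta * x[0] * x[1]
    pred_t = delta * x[0] * x[1] - gamma * x[1]
    return jnp.array([prey_t, pred_t])


def rk4_integrate(f, y0, t0, dt, nsteps):
    """Classical fixed-step RK4; returns nsteps + 1 times and states, y0 included."""
    # Use a strongly typed scalar so the scan carry keeps a consistent dtype
    t0 = jnp.asarray(t0, dtype=y0.dtype)

    def step(carry, _):
        t, y = carry
        k1 = f(t, y)
        k2 = f(t + 0.5 * dt, y + 0.5 * dt * k1)
        k3 = f(t + 0.5 * dt, y + 0.5 * dt * k2)
        k4 = f(t + dt, y + dt * k3)
        y_next = y + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        return (t + dt, y_next), (t + dt, y_next)

    _, (ts, ys) = jax.lax.scan(step, (t0, y0), xs=None, length=nsteps)
    ts = jnp.concatenate([t0[None], ts])
    ys = jnp.concatenate([y0[None, :], ys], axis=0)
    return ts, ys


def lv_invariant(x, alpha=1.0, beta=1.0, delta=1.0, gamma=1.0):
    """Conserved quantity of the exact flow (x[..., 0] prey, x[..., 1] predator)."""
    return (delta * x[..., 0] - gamma * jnp.log(x[..., 0])
            + beta * x[..., 1] - alpha * jnp.log(x[..., 1]))


# Same initial condition and step settings as the quick-start example
y0 = jnp.array([0.1, 0.2])
t_ref, y_ref = rk4_integrate(lotka_volterra_rhs, y0, 0.0, 0.2, 100)
drift = float(jnp.abs(lv_invariant(y_ref[-1]) - lv_invariant(y0)))
print(f"t_end = {float(t_ref[-1]):.3f}, invariant drift = {drift:.2e}")
```

Since RK4 is fourth-order accurate, halving dt (and doubling nsteps) should shrink the invariant drift substantially, which is a quick way to judge whether a reference run is trustworthy.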
Optionally, an explicit Jacobian can be supplied. If none is supplied, as in the example above, the Jacobian is obtained by automatic differentiation.
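For the Lotka-Volterra right-hand side f defined by _frhs (x_0 prey, x_1 predator), differentiating each component with respect to x_0 and x_1 gives the 2x2 Jacobian that an explicit implementation would encode:

```latex
J(x) = \frac{\partial f}{\partial x} =
\begin{pmatrix}
\alpha - \beta x_1 & -\beta x_0 \\
\delta x_1 & \delta x_0 - \gamma
\end{pmatrix}
```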
class LotkaVolterra(OdeSys):
    alpha: float
    beta: float
    delta: float
    gamma: float

    def __init__(self, a=1.0, b=1.0, d=1.0, g=1.0, **kwargs):
        super().__init__()
        self.alpha = a
        self.beta = b
        self.delta = d
        self.gamma = g

    @jax.jit
    def _frhs(self, t, x, **kwargs):
        prey_t = self.alpha * x[0] - self.beta * x[0] * x[1]
        pred_t = self.delta * x[0] * x[1] - self.gamma * x[1]
        return jnp.array([prey_t, pred_t])

    @jax.jit
    def _fjac(self, t, x, **kwargs):
        # Row i holds the partial derivatives of component i of _frhs
        # with respect to x[0] (prey) and x[1] (predator)
        jac = jnp.array([
            [self.alpha - self.beta * x[1], -self.beta * x[0]],
            [self.delta * x[1], self.delta * x[0] - self.gamma]
        ])
        return CustomJacLinOp(t, x, self.frhs, jac)
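A hand-written Jacobian is easy to get subtly wrong, so it can be worth checking it against automatic differentiation before use. The sketch below compares jax.jacfwd of a plain-function copy of _frhs with a plain-function copy of the jac array from _fjac; the helper names are hypothetical and not part of ORMATEX, and CustomJacLinOp is deliberately not called because its interface is not assumed here.

```python
import jax
import jax.numpy as jnp


def lotka_volterra_rhs(t, x, alpha=1.0, beta=1.0, delta=1.0, gamma=1.0):
    """Plain-function copy of LotkaVolterra._frhs (hypothetical helper)."""
    prey_t = alpha * x[0] - beta * x[0] * x[1]
    pred_t = delta * x[0] * x[1] - gamma * x[1]
    return jnp.array([prey_t, pred_t])


def lotka_volterra_jac(t, x, alpha=1.0, beta=1.0, delta=1.0, gamma=1.0):
    """Same entries as the jac array built in LotkaVolterra._fjac (hypothetical helper)."""
    return jnp.array([
        [alpha - beta * x[1], -beta * x[0]],
        [delta * x[1], delta * x[0] - gamma],
    ])


def jacobian_matches(t, x, params=(), atol=1e-5):
    """Compare the hand-written Jacobian with forward-mode AD w.r.t. the state x."""
    # argnums=1 differentiates with respect to x; t and params stay constant
    jac_ad = jax.jacfwd(lotka_volterra_rhs, argnums=1)(t, x, *params)
    jac_hand = lotka_volterra_jac(t, x, *params)
    return bool(jnp.allclose(jac_ad, jac_hand, atol=atol))


x = jnp.array([0.1, 0.2])
print(jacobian_matches(0.0, x))                        # default parameters
print(jacobian_matches(0.0, x, (1.5, 0.8, 0.6, 1.2)))  # non-default parameters
```

Using non-default parameters in the second check matters: with alpha = beta = delta = gamma = 1.0, a Jacobian that mixes up two parameters would still pass.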