Minor Bug: `linear_util` deprecation in newer `jax` version leads to import failure of `ODE_*_dt.py`
Hi there! I was just setting up the basis for some `sphinx` docs in a fresh environment and stumbled upon the following problem: the `linear_util` module used in the `ODE_*_dt` modules has been removed in newer `jax` versions, so these modules fail to import and none of their functions can be used.
Steps to reproduce
- Install the current `jax` version and the package in a new venv:
  ```bash
  python -m venv .venv
  source .venv/bin/activate  # on Windows: .venv\Scripts\activate
  pip install -r requirements_cpu.txt
  ```
- Start a Python session and try to import one of the affected modules:
  ```pycon
  >>> from adoptODE import ODE_Exp_dt
  Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
    File "C:\Users\kottl\Projects\2024\adoptODE\adoptODE\ODE_Exp_dt.py", line 41, in <module>
      from jax import linear_util as lu
  ImportError: cannot import name 'linear_util' from 'jax' (C:\Users\kottl\Projects\2024\adoptODE\.venv\Lib\site-packages\jax\__init__.py)
  ```
Environment
I got this problem both on Windows 10 with Python 3.11.9 and on the GWDG HPC frontends (some Linux distro) with Python 3.11.6. Both times, I was using `jax` 0.4.30 (CPU version).
Source of error
`jax.linear_util` was deprecated in JAX v0.4.16 and removed in JAX v0.4.24.
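To make the version window concrete, here is a small sketch that classifies a `jax` version string according to the two boundaries stated above. The helpers `parse_release` and `linear_util_status` are hypothetical illustrations, not part of adoptODE or JAX, and any pre-release suffix (e.g. `rc1`) is simply ignored.

```python
import re

# Boundaries as stated in this issue (assumed to be exact release numbers).
DEPRECATED_IN = (0, 4, 16)  # first release in which jax.linear_util is deprecated
REMOVED_IN = (0, 4, 24)     # first release without jax.linear_util


def parse_release(version: str) -> tuple[int, ...]:
    """Parse the leading MAJOR.MINOR.PATCH of a version string such as '0.4.30'."""
    match = re.match(r"(\d+)\.(\d+)\.(\d+)", version)
    if match is None:
        raise ValueError(f"unrecognised version string: {version!r}")
    return tuple(int(part) for part in match.groups())


def linear_util_status(jax_version: str) -> str:
    """Classify whether `from jax import linear_util` works for a given jax version."""
    release = parse_release(jax_version)
    if release >= REMOVED_IN:
        return "removed"      # the import fails with ImportError
    if release >= DEPRECATED_IN:
        return "deprecated"   # still importable, but deprecated
    return "available"
```

Called as `linear_util_status(importlib.metadata.version("jax"))`, this returns `"removed"` for the `jax` 0.4.30 installs described above.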
Implications
As long as the user doesn't need anything from the `ODE_*_dt` modules, there is no problem (this is actually my use case; I only discovered the issue because `sphinx` complained).
In the examples, only the `RayleighBernard` notebook is affected.
Possible solutions
Short term
Restrict the `jax` version to `jax<0.4.24` (e.g. in `requirements_cpu.txt`).
Long term
Replace `linear_util` with something that is still supported in current `jax` versions.
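One possible direction, sketched under assumptions: newer JAX releases expose a `jax.extend.linear_util` module, which is assumed here to be the intended replacement and to re-export the names the `ODE_*_dt` modules use (e.g. `wrap_init`). This has not been verified against adoptODE, and `jax.extend` APIs may still change between releases. The hypothetical helper `import_first` tries the new location first and falls back to the old one, so the same code keeps working on older pinned versions.

```python
import importlib
from types import ModuleType

# Candidate locations for linear_util, newest first (assumption: jax.extend.linear_util
# is the replacement; jax.linear_util only exists in jax < 0.4.24).
LINEAR_UTIL_CANDIDATES = ("jax.extend.linear_util", "jax.linear_util")


def import_first(candidates: tuple[str, ...]) -> ModuleType:
    """Return the first module in `candidates` that can be imported."""
    errors = []
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError as exc:  # also catches ModuleNotFoundError
            errors.append(f"{name}: {exc}")
    raise ImportError(
        "none of the candidate modules could be imported:\n" + "\n".join(errors)
    )
```

In `ODE_Exp_dt.py` (and the other `ODE_*_dt` modules), `from jax import linear_util as lu` would then become `lu = import_first(LINEAR_UTIL_CANDIDATES)`. A plain `try: from jax.extend import linear_util as lu` / `except ImportError: from jax import linear_util as lu` achieves the same; the helper just reports every failed candidate in one error message.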