Installation

Requirements

DiRe-JAX has two sets of dependencies:

  • Core dependencies (required): jax, numpy, scipy, tqdm, pandas, plotly, loguru, scikit-learn

  • Utilities dependencies (optional): ripser, persim, fastdtw, fast-twed, pot
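As a quick way to see which of these dependencies are already present in an environment, the following is a minimal sketch using only the Python standard library. It assumes the names listed above are the pip distribution names; the helper `missing_packages` is hypothetical and not part of DiRe-JAX.

```python
from importlib import metadata

# Distribution names copied from the two dependency lists above
# (assumed to match the names used on PyPI).
CORE = ["jax", "numpy", "scipy", "tqdm", "pandas", "plotly", "loguru", "scikit-learn"]
UTILS = ["ripser", "persim", "fastdtw", "fast-twed", "pot"]


def missing_packages(names):
    """Return the names for which no installed distribution is found."""
    missing = []
    for name in names:
        try:
            metadata.version(name)  # raises if the distribution is absent
        except metadata.PackageNotFoundError:
            missing.append(name)
    return missing


if __name__ == "__main__":
    print("Missing core dependencies:", missing_packages(CORE))
    print("Missing utility dependencies:", missing_packages(UTILS))
```

An empty list for the core set suggests the standard installation is complete. Entries in the second list only matter if the benchmarking and metrics utilities are needed.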

Important

JAX GPU/TPU Support

For GPU or TPU acceleration, JAX needs to be specifically installed with hardware support. The default JAX installation through pip doesn’t include GPU/TPU support.

To enable GPU/TPU acceleration, follow the JAX installation instructions: <https://github.com/google/jax#installation>.

Installing JAX with hardware acceleration can significantly improve the performance of DiRe-JAX, especially for larger datasets.
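To check whether the installed JAX actually sees an accelerator, one option is to ask JAX which backend it selected. The sketch below relies on `jax.default_backend()` and `jax.devices()`; the wrapper `describe_jax_backend` is a hypothetical helper for illustration, not part of DiRe-JAX.

```python
def describe_jax_backend():
    """Return a short description of the backend JAX will use, if importable."""
    try:
        import jax
    except ImportError:
        return "jax is not installed"
    backend = jax.default_backend()  # typically "cpu", "gpu", or "tpu"
    devices = jax.devices()          # devices available on that backend
    return f"backend={backend}, devices={len(devices)}"


if __name__ == "__main__":
    print(describe_jax_backend())
```

If the reported backend is "cpu" on a machine with a GPU or TPU, the CPU-only build of JAX is most likely the one installed.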

Installation Options

Standard Installation

To install the main DiRe class only:

pip install dire-jax

Installation with Utilities

To install DiRe-JAX with additional utilities for benchmarking and metrics:

pip install "dire-jax[utils]"

Development Installation

To install for development:

git clone https://github.com/sashakolpakov/dire-jax.git
cd dire-jax
pip install -e ".[utils]"

After installation, you may need to install JAX with GPU/TPU support separately as described above.