Installation

Requirements

DiRe-JAX has the following dependencies:

  • Core dependencies (required): jax, numpy, scipy, tqdm, pandas, plotly, loguru, scikit-learn

  • Utility dependencies (optional): ripser, persim, fastdtw, fast-twed, pot
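To see which of these packages are present in the current environment, a small check like the following can help. It is a sketch that uses only the standard library's `importlib.metadata`, which looks packages up by their pip distribution names (as listed above), so no import names have to be guessed. The helper `check_dependencies` is hypothetical and not part of DiRe-JAX.

```python
from importlib.metadata import PackageNotFoundError, version

# pip distribution names, as listed in the requirements above.
CORE = ["jax", "numpy", "scipy", "tqdm", "pandas", "plotly", "loguru", "scikit-learn"]
UTILS = ["ripser", "persim", "fastdtw", "fast-twed", "pot"]


def check_dependencies(names):
    """Map each distribution name to its installed version, or None if missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    for group, names in (("core", CORE), ("utils", UTILS)):
        for name, ver in check_dependencies(names).items():
            status = ver if ver is not None else "not installed"
            print(f"[{group}] {name}: {status}")
```

A missing core package means the basic installation is incomplete; missing utility packages only matter if you use the benchmarking utilities.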

JAX Implementation


DiRe-JAX Features

  • Optimized for small to medium-sized datasets (under 50K points)

  • Large dataset mode for datasets >65K points with automatic memory management

  • Fast CPU execution through JAX JIT compilation

  • GPU acceleration available when JAX is installed with CUDA support

  • TPU support for cloud-based computation

  • Mixed precision arithmetic (MPA) support for enhanced performance

  • Memory-efficient chunking to prevent out-of-memory issues

  • Optimized kernel caching to minimize recompilation overhead

  • Ideal for research and development workflows
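The memory-efficient chunking mentioned above follows a general pattern: process the data in fixed-size blocks so that a full n-by-n array is never materialized. The sketch below illustrates that pattern with NumPy for a nearest-neighbor search. It is not DiRe-JAX's internal code; the function name `nearest_neighbors_chunked` and the `chunk_size` default are illustrative assumptions.

```python
import numpy as np


def nearest_neighbors_chunked(x, chunk_size=1024):
    """Return, for each row of x, the index of its nearest other row.

    Distances are computed one block of rows at a time, so only a
    (chunk_size, n) block exists in memory instead of the full (n, n) matrix.
    """
    x = np.asarray(x, dtype=np.float64)
    n = x.shape[0]
    sq_norms = np.sum(x * x, axis=1)
    result = np.empty(n, dtype=np.int64)
    for start in range(0, n, chunk_size):
        stop = min(start + chunk_size, n)
        rows = x[start:stop]
        # Squared distances ||a||^2 - 2 a.b + ||b||^2 for this block of rows.
        d2 = sq_norms[start:stop, None] - 2.0 * rows @ x.T + sq_norms[None, :]
        # Exclude each point's distance to itself.
        d2[np.arange(stop - start), np.arange(start, stop)] = np.inf
        result[start:stop] = np.argmin(d2, axis=1)
    return result


points = np.array([[0.0, 0.0], [1.0, 0.0], [10.0, 0.0], [10.5, 0.0]])
print(nearest_neighbors_chunked(points, chunk_size=3))  # [1 0 3 2]
```

For scale: with 50K points, a full float64 distance matrix needs about 50,000² × 8 bytes ≈ 20 GB, while a 1,024-row block needs about 1,024 × 50,000 × 8 bytes ≈ 410 MB. In a JAX version wrapped in `jax.jit`, every new input shape triggers a fresh compilation, so the ragged final block is commonly padded to the full chunk size so that one compiled kernel is reused.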

Installation Options

Basic Installation

pip install dire-jax

With Utilities for Benchmarking

pip install "dire-jax[utils]"

Complete Installation (with all utilities)

pip install "dire-jax[all]"

Development Installation

git clone https://github.com/sashakolpakov/dire-jax.git
cd dire-jax
pip install -e ".[all]"

Hardware Acceleration

JAX GPU Support

For GPU acceleration, JAX needs to be installed with CUDA support:

# For CUDA 12
pip install --upgrade "jax[cuda12]"

# For CUDA 11 (older JAX releases only; recent releases require CUDA 12)
pip install --upgrade "jax[cuda11]"

See the JAX GPU installation guide for detailed instructions.

JAX TPU Support

For TPU support on Google Cloud:

pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

See the JAX TPU documentation for more information.
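After installing a GPU or TPU build, you can confirm that JAX actually sees the accelerator by querying its default backend. The sketch below uses the public `jax.default_backend()` and `jax.devices()` calls; the wrapper `jax_backend_report` is a hypothetical helper, not part of DiRe-JAX.

```python
def jax_backend_report():
    """Return (default backend, device platforms), or None if JAX is not installed."""
    try:
        import jax
    except ImportError:
        return None
    return jax.default_backend(), [d.platform for d in jax.devices()]


if __name__ == "__main__":
    report = jax_backend_report()
    if report is None:
        print("JAX is not installed")
    else:
        backend, platforms = report
        print(f"default backend: {backend}; devices: {platforms}")
```

If this reports `cpu` on a machine that has a GPU or TPU, JAX has fallen back to the CPU. This usually means the accelerator-enabled build was not installed.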