Installation
Requirements
DiRe-JAX has the following dependencies:
Core dependencies (required): jax, numpy, scipy, tqdm, pandas, plotly, loguru, scikit-learn
Utility dependencies (optional): ripser, persim, fastdtw, fast-twed, pot
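To see which of these packages are already present in the current environment, a small script along the following lines can help. It is a sketch, not part of DiRe-JAX: it assumes the names listed above are the PyPI distribution names, and it queries them with the standard-library `importlib.metadata`, so nothing is actually imported. The helper names `installed_versions` and `report` are hypothetical.

```python
from importlib.metadata import PackageNotFoundError, version

# Distribution names as listed above (assumed to match the PyPI names).
CORE = ["jax", "numpy", "scipy", "tqdm", "pandas", "plotly", "loguru", "scikit-learn"]
UTILS = ["ripser", "persim", "fastdtw", "fast-twed", "pot"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if it is missing."""
    found = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def report():
    """Print one line per dependency, grouped into core and optional utilities."""
    for group, names in (("core", CORE), ("utils", UTILS)):
        for name, ver in installed_versions(names).items():
            status = ver if ver is not None else "MISSING"
            print(f"[{group}] {name}: {status}")


if __name__ == "__main__":
    report()
```

A missing core package points to a broken basic installation, while missing utilities only matter for the benchmarking extras.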
JAX Implementation
Important
DiRe-JAX Features
Optimized for small to medium-sized datasets (<50K points)
Large dataset mode for datasets >65K points with automatic memory management
Fast CPU execution through JAX's just-in-time (JIT) compilation
GPU acceleration available when JAX is installed with CUDA support
TPU support for cloud-based computation
Mixed precision arithmetic (MPA) support for enhanced performance
Memory-efficient chunking to prevent out-of-memory issues
Optimized kernel caching to minimize recompilation overhead
Ideal for research and development workflows
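The chunking and kernel-caching features listed above can be illustrated with a generic JAX pattern. This is a sketch only and not DiRe-JAX's internal code: the names `chunked_nearest_sq_dist` and `_chunk_sq_dists` and the default `chunk_size` are hypothetical. It computes each point's squared distance to its nearest neighbour one block of rows at a time, so each step materializes a `(chunk_size, n)` block rather than the full `(n, n)` matrix, and it pads the final block so every call to the jitted kernel has the same shape and is compiled only once per dataset shape.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def _chunk_sq_dists(chunk, data):
    """Squared Euclidean distances between a (c, d) block and all (n, d) points."""
    # ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2, evaluated with one matrix product.
    a2 = jnp.sum(chunk * chunk, axis=1, keepdims=True)  # (c, 1)
    b2 = jnp.sum(data * data, axis=1)  # (n,)
    cross = jnp.matmul(chunk, data.T, precision=jax.lax.Precision.HIGHEST)  # (c, n)
    # Clamp tiny negative values caused by floating-point cancellation.
    return jnp.maximum(a2 - 2.0 * cross + b2, 0.0)


def chunked_nearest_sq_dist(data, chunk_size=1024):
    """Squared distance from each point to its nearest other point.

    Each loop step materializes one (chunk_size, n) block of distances
    instead of the full (n, n) matrix.
    """
    data = jnp.asarray(data)
    n = data.shape[0]
    results = []
    for start in range(0, n, chunk_size):
        chunk = data[start:start + chunk_size]
        rows = chunk.shape[0]
        # Pad the last block to chunk_size so every call has the same input shape;
        # jax.jit then reuses one compiled kernel instead of compiling a new one.
        if rows < chunk_size:
            chunk = jnp.pad(chunk, ((0, chunk_size - rows), (0, 0)))
        block = _chunk_sq_dists(chunk, data)[:rows]
        # Exclude each point's distance to itself.
        local = jnp.arange(rows)
        block = block.at[local, start + local].set(jnp.inf)
        results.append(jnp.min(block, axis=1))
    return jnp.concatenate(results)


if __name__ == "__main__":
    points = np.random.default_rng(0).normal(size=(5000, 16)).astype(np.float32)
    nearest = chunked_nearest_sq_dist(points, chunk_size=1024)
    print(nearest.shape)  # (5000,)
```

The `chunk_size` trades memory for speed: larger blocks mean fewer kernel launches but a bigger intermediate array.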
Installation Options
Basic Installation
pip install dire-jax
With Utilities for Benchmarking
pip install "dire-jax[utils]"
Complete Installation (with all utilities)
pip install "dire-jax[all]"
Development Installation
git clone https://github.com/sashakolpakov/dire-jax.git
cd dire-jax
pip install -e ".[all]"
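After a development install, one way to confirm that `dire-jax` is installed in editable mode (so changes in the cloned repository take effect without reinstalling) is to read the `direct_url.json` record that pip writes as specified in PEP 610. The helper names below are hypothetical; the sketch assumes the distribution is named `dire-jax`, as in the `pip install` commands above, and a pip version recent enough to write this record.

```python
import json
from importlib.metadata import PackageNotFoundError, distribution


def parse_editable(direct_url_text):
    """Return True if a PEP 610 direct_url.json payload marks an editable install."""
    if not direct_url_text:
        return False
    info = json.loads(direct_url_text)
    return bool(info.get("dir_info", {}).get("editable", False))


def is_editable_install(dist_name="dire-jax"):
    """Check whether `dist_name` is installed and was installed with `pip install -e`."""
    try:
        dist = distribution(dist_name)
    except PackageNotFoundError:
        return False
    # read_text returns None when the file is absent (e.g. a regular PyPI install).
    return parse_editable(dist.read_text("direct_url.json"))


if __name__ == "__main__":
    print("editable install:", is_editable_install())
```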
Hardware Acceleration
JAX GPU Support
For GPU acceleration, JAX needs to be installed with CUDA support:
# For CUDA 12
pip install --upgrade "jax[cuda12]"
# For CUDA 11
pip install --upgrade "jax[cuda11]"
See the JAX GPU installation guide for detailed instructions.
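Once a CUDA-enabled JAX is installed, a quick check is to ask JAX which devices it can see. The sketch below is an illustration, not a DiRe-JAX API (`gpu_devices` is a hypothetical helper); it relies on `jax.devices()` listing the devices of the default backend, which is the GPU backend when a CUDA build of JAX finds a usable GPU.

```python
import jax


def gpu_devices():
    """Return the GPU devices JAX can see (an empty list on a CPU-only install)."""
    # jax.devices() lists the devices of the default backend; with a working
    # CUDA build and a visible GPU, that backend is the GPU one.
    return [d for d in jax.devices() if d.platform == "gpu"]


if __name__ == "__main__":
    print("default backend:", jax.default_backend())
    gpus = gpu_devices()
    if gpus:
        print(f"{len(gpus)} GPU(s):", gpus)
    else:
        print("No GPU visible; JAX will run on", jax.default_backend())
```

If the list is empty on a machine with a GPU, the usual suspects are a CPU-only `jaxlib` or a CUDA version mismatch.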
JAX TPU Support
For TPU support on Google Cloud:
pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
See the JAX TPU documentation for more information.
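On a TPU VM, a simple smoke test is to run a small jitted computation and check which platform the result lives on. The sketch below is hypothetical (`where_does_it_run` is not a DiRe-JAX function); it works unchanged on CPU or GPU, where it simply reports `cpu` or `gpu` instead of `tpu`.

```python
import jax
import jax.numpy as jnp


@jax.jit
def _smoke_test(x):
    # A small matrix product, enough to make XLA compile and run a kernel.
    return jnp.sum(x @ x.T)


def where_does_it_run(n=256):
    """Run a jitted op on the default device; return (platform, device count, result)."""
    x = jnp.ones((n, n), dtype=jnp.float32)
    y = _smoke_test(x)
    y.block_until_ready()  # wait for JAX's asynchronous dispatch to finish
    platform = next(iter(y.devices())).platform
    return platform, jax.device_count(), float(y)


if __name__ == "__main__":
    platform, count, value = where_does_it_run()
    print(f"ran on {platform}; {count} device(s) visible; result={value}")
```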