Installation
Requirements
DiRe-JAX has the following dependencies:
- Core dependencies (required): jax, numpy, scipy, tqdm, pandas, plotly, loguru, scikit-learn 
- Utility dependencies (optional, used for benchmarking): ripser, persim, fastdtw, fast-twed, pot
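To see which of these packages are already present in an environment, a small check along the following lines may help. It is a minimal sketch that uses only the standard library and looks packages up by the distribution names listed above; the helper names `installed_versions` and `report` are illustrative and not part of DiRe-JAX.

```python
from importlib import metadata

# Distribution names copied from the requirements list above.
CORE = ["jax", "numpy", "scipy", "tqdm", "pandas", "plotly", "loguru", "scikit-learn"]
OPTIONAL = ["ripser", "persim", "fastdtw", "fast-twed", "pot"]


def installed_versions(names):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def report(label, names):
    """Print one line per package with its version or MISSING (hypothetical helper)."""
    for name, version in installed_versions(names).items():
        print(f"{label:9s} {name:13s} {version or 'MISSING'}")


if __name__ == "__main__":
    report("core", CORE)
    report("optional", OPTIONAL)
```

Missing optional packages only matter if you plan to use the benchmarking utilities.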
JAX Implementation
Important
DiRe-JAX Features
- Optimized for small-medium datasets (<50K points) 
- Large dataset mode for datasets >65K points with automatic memory management 
- Excellent CPU performance with JIT compilation 
- GPU acceleration available when JAX is installed with CUDA support 
- TPU support for cloud-based computation 
- Mixed precision arithmetic (MPA) support for enhanced performance 
- Memory-efficient chunking to prevent out-of-memory issues 
- Optimized kernel caching to minimize recompilation overhead 
- Ideal for research and development workflows 
Installation Options
Basic Installation
pip install dire-jax
With Utilities for Benchmarking
pip install "dire-jax[utils]"
Complete Installation (with all utilities)
pip install "dire-jax[all]"
Development Installation
git clone https://github.com/sashakolpakov/dire-jax.git
cd dire-jax
pip install -e ".[all]"
Hardware Acceleration
JAX GPU Support
For GPU acceleration, JAX needs to be installed with CUDA support:
# For CUDA 12
pip install --upgrade "jax[cuda12]"
# For CUDA 11 (older JAX releases only; recent releases no longer ship CUDA 11 builds)
pip install --upgrade "jax[cuda11]"
See the JAX GPU installation guide for detailed instructions.
JAX TPU Support
For TPU support on Google Cloud:
pip install --upgrade "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
See the JAX TPU documentation for more information.
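After installing a GPU or TPU build, it can be useful to confirm which backend JAX actually selected, since JAX may fall back to CPU when the accelerator is not usable. The following is a minimal sketch using the public `jax.default_backend()` and `jax.devices()` calls; the function name `describe_backend` is illustrative and not part of DiRe-JAX.

```python
def describe_backend():
    """Return the active JAX backend and visible devices, or note that JAX is absent."""
    try:
        import jax
    except ImportError:
        # JAX is not installed in this environment.
        return {"installed": False, "backend": None, "devices": []}
    return {
        "installed": True,
        "backend": jax.default_backend(),  # e.g. "cpu", "gpu", or "tpu"
        "devices": [str(d) for d in jax.devices()],
    }


info = describe_backend()
print(info)
```

If `backend` reports `cpu` on a machine with an accelerator, the accelerator-specific JAX build is probably not installed or not visible to the process.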