Welcome to DiRe-JAX’s documentation!
DiRe-JAX is a high-performance dimensionality reduction package built with JAX for efficient computation on CPUs, GPUs, and TPUs.
Quick Start
Installation
Basic installation:
pip install dire-jax
With utilities for benchmarking and metrics:
pip install "dire-jax[utils]"
Complete installation (all utilities):
pip install "dire-jax[all]"
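After installing, it can help to confirm that the packages are importable and to see which backend JAX has picked up. The snippet below is a minimal sketch, not part of DiRe-JAX: the helper name `check_install` is hypothetical, and it relies only on the standard library plus JAX's public `jax.default_backend()` and `jax.devices()` calls.

```python
import importlib.util


def check_install(packages=("jax", "dire_jax")):
    """Return a dict mapping each package name to whether it can be imported."""
    return {name: importlib.util.find_spec(name) is not None for name in packages}


status = check_install()
for name, ok in status.items():
    print(f"{name}: {'found' if ok else 'missing'}")

# Only query JAX if it is actually installed.
if status["jax"]:
    import jax

    print("JAX backend:", jax.default_backend())  # e.g. cpu, gpu, or tpu
    print("Devices:", jax.devices())
```

If the backend reported is `cpu` on a machine with an accelerator, the installed JAX build most likely lacks accelerator support.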
Example Usage
from dire_jax import DiRe
from sklearn.datasets import make_blobs
# Create sample data
features, labels = make_blobs(
    n_samples=10000,
    n_features=100,
    centers=5,
    random_state=42
)
# Initialize reducer
reducer = DiRe(
    n_components=2,
    n_neighbors=16,
    max_iter_layout=32
)
# Fit and transform
embedding = reducer.fit_transform(features)
# Visualize results
reducer.visualize(labels=labels, point_size=4)
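One rough way to sanity-check an embedding such as the one returned by `fit_transform` is to measure how many of each point's nearest neighbors survive the reduction. The sketch below uses NumPy only and is not the package's own metrics utility: the function names `knn_indices` and `neighbor_preservation` are hypothetical, and the "embedding" here is a stand-in (the first two coordinates of random data) so that the example runs without DiRe-JAX installed.

```python
import numpy as np


def knn_indices(x, k):
    """Brute-force k nearest neighbors (excluding the point itself) for each row of x."""
    sq = np.sum(x * x, axis=1)
    d = sq[:, None] + sq[None, :] - 2.0 * x @ x.T  # squared Euclidean distances
    np.fill_diagonal(d, np.inf)  # a point is never its own neighbor
    return np.argsort(d, axis=1)[:, :k]


def neighbor_preservation(high, low, k=16):
    """Mean fraction of each point's k nearest neighbors in `high` that are also among its k nearest in `low`."""
    nn_high = knn_indices(high, k)
    nn_low = knn_indices(low, k)
    overlap = [len(set(a) & set(b)) for a, b in zip(nn_high, nn_low)]
    return float(np.mean(overlap)) / k


rng = np.random.default_rng(0)
high = rng.normal(size=(500, 20))
low = high[:, :2]  # stand-in for an embedding, NOT a DiRe-JAX result

print(neighbor_preservation(high, high, k=16))  # identical inputs give 1.0
print(neighbor_preservation(high, low, k=16))   # some value between 0 and 1
```

To apply this to the example above, one would pass `features` as `high` and `embedding` as `low`. The brute-force distance matrix is quadratic in the number of points, so for 10,000 samples it is advisable to evaluate on a random subsample.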
Key Features
JAX-powered: Leverages JAX for JIT compilation and automatic differentiation
Hardware acceleration: Supports CPU, GPU (via CUDA), and TPU
Scalable: Optimized for datasets of up to 50K points, with a large-dataset mode for datasets of more than 65K points
Memory-efficient: Advanced chunking and memory management for large datasets
High-performance: Mixed-precision arithmetic (MPA) and optimized kernel caching
Research-friendly: Clean, modular design for experimentation
Benchmarking tools: Built-in utilities for performance evaluation
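To make the JIT-compilation and chunking ideas in the list above concrete, here is a minimal JAX sketch of a chunked nearest-neighbor search. It illustrates the general technique only and is not DiRe-JAX's implementation: the names `sq_dists` and `nearest_neighbors_chunked` and the `chunk_size` parameter are assumptions made for this example, and it assumes the input contains no duplicate points.

```python
import jax
import jax.numpy as jnp


@jax.jit
def sq_dists(chunk, points):
    """Squared Euclidean distances between each row of chunk and each row of points."""
    diff = chunk[:, None, :] - points[None, :, :]
    return jnp.sum(diff * diff, axis=-1)


def nearest_neighbors_chunked(points, k, chunk_size=256):
    """Indices of the k nearest neighbors of every point, computed one chunk at a time."""
    out = []
    for start in range(0, points.shape[0], chunk_size):
        chunk = points[start:start + chunk_size]
        d = sq_dists(chunk, points)  # only a (chunk_size, n) block is held at once
        # top_k picks the largest values, so negate the distances.
        # Column 0 is the point itself (distance 0), so it is dropped.
        _, idx = jax.lax.top_k(-d, k + 1)
        out.append(idx[:, 1:])
    return jnp.concatenate(out, axis=0)


key = jax.random.PRNGKey(0)
points = jax.random.normal(key, (1000, 16))
neighbors = nearest_neighbors_chunked(points, k=5)
print(neighbors.shape)  # (1000, 5)
```

Processing rows in chunks bounds peak memory by the chunk size rather than by the full n-by-n distance matrix, and `jax.jit` compiles the distance kernel once per distinct chunk shape, so only a shorter final chunk triggers a second compilation.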