Welcome to DiRe-JAX’s documentation!


DiRe-JAX is a high-performance dimensionality reduction package built with JAX for efficient computation on CPUs, GPUs, and TPUs.

Quick Start

Installation

Basic installation:

pip install dire-jax

With utilities for benchmarking and metrics:

pip install "dire-jax[utils]"

Complete installation (all utilities):

pip install "dire-jax[all]"
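(The quotes around the extras keep shells such as zsh from treating the square brackets as a glob pattern.)

DiRe-JAX does its computation through JAX, so it runs on whatever backend JAX detects. A quick way to confirm that an installation picked up the accelerator you expect is to ask JAX directly. The snippet below uses only standard JAX calls, not DiRe-JAX itself. GPU use generally also requires a CUDA-enabled JAX build, for example pip install -U "jax[cuda12]" as described in JAX's own installation guide. Whether the dire-jax extras pull that in is an assumption you should check for your setup.

```python
import jax
import jax.numpy as jnp

# Backend JAX will use by default, typically "cpu", "gpu", or "tpu".
backend = jax.default_backend()
print("Default backend:", backend)

# Every device visible to JAX on this host.
for device in jax.devices():
    print(device)

# A tiny computation to confirm arrays can be placed and computed on the backend.
x = jnp.ones((4, 4))
# Each entry of x @ x is 4, and there are 16 entries, so the sum is 64.0.
print(float((x @ x).sum()))
```

If the backend prints "cpu" on a machine with a GPU, the installed jaxlib is most likely a CPU-only build.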

Example Usage

from dire_jax import DiRe
from sklearn.datasets import make_blobs

# Create sample data
features, labels = make_blobs(
    n_samples=10000,
    n_features=100,
    centers=5,
    random_state=42
)

# Initialize reducer
reducer = DiRe(
    n_components=2,
    n_neighbors=16,
    max_iter_layout=32
)

# Fit and transform
embedding = reducer.fit_transform(features)

# Visualize results
reducer.visualize(labels=labels, point_size=4)
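To put a number on how well an embedding like the one above keeps local structure, one common check is k-nearest-neighbor preservation: for each point, the fraction of its k nearest neighbors in the original space that are also among its k nearest neighbors in the embedding. The sketch below is a generic illustration built on scikit-learn's NearestNeighbors. The helper name knn_preservation is hypothetical and is not part of DiRe-JAX's benchmarking utilities, and it assumes the embedding can be converted to a NumPy array of shape (n_samples, n_components).

```python
import numpy as np
from sklearn.neighbors import NearestNeighbors


def knn_preservation(X, Y, k=16):
    """Hypothetical helper: mean fraction of each point's k nearest neighbors
    in X that are also among its k nearest neighbors in Y (1.0 is perfect)."""
    X = np.asarray(X)
    Y = np.asarray(Y)
    # Calling kneighbors() with no query array returns neighbors of the fitted
    # points and excludes each point from its own neighbor list.
    nn_x = NearestNeighbors(n_neighbors=k).fit(X).kneighbors(return_distance=False)
    nn_y = NearestNeighbors(n_neighbors=k).fit(Y).kneighbors(return_distance=False)
    # Size of the overlap between the two neighbor sets, point by point.
    overlaps = [np.intersect1d(a, b).size for a, b in zip(nn_x, nn_y)]
    return float(np.mean(overlaps)) / k
```

With the example above, knn_preservation(features, embedding, k=16) returns a score between 0 and 1. Matching k to n_neighbors keeps the metric aligned with the neighborhood size the reducer was configured with, although that pairing is a choice, not a requirement.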

Key Features

  • JAX-powered: Leverages JAX for JIT compilation and automatic differentiation

  • Hardware acceleration: Supports CPU, GPU (via CUDA), and TPU

  • Scalable: Optimized for datasets of up to 50K points, with a large-dataset mode for more than 65K points

  • Memory-efficient: Advanced chunking and memory management for large datasets

  • High-performance: Mixed precision arithmetic (MPA) and optimized kernel caching

  • Research-friendly: Clean, modular design for experimentation

  • Benchmarking tools: Built-in utilities for performance evaluation
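The combination of JIT compilation and chunking listed above can be illustrated with a generic pattern: find each point's nearest neighbors one block of rows at a time, so only a (chunk_size x n) slice of the distance matrix exists in memory at once, with the per-block kernel compiled by jax.jit. This is a minimal sketch of that general technique, not DiRe-JAX's actual implementation. The names chunked_knn, _block_knn, and chunk_size are hypothetical.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np


@partial(jax.jit, static_argnames="k")
def _block_knn(block, data, k):
    # Squared Euclidean distances between a block of rows and all points,
    # expanded as |a|^2 - 2 a.b + |b|^2 so no (rows x n x dim) tensor is built.
    sq = (
        jnp.sum(block**2, axis=1, keepdims=True)
        - 2.0 * block @ data.T
        + jnp.sum(data**2, axis=1)
    )
    # top_k picks the largest values, so negate to select the smallest distances.
    _, idx = jax.lax.top_k(-sq, k)
    return idx


def chunked_knn(data, k, chunk_size=1024):
    """Hypothetical helper: indices of the k nearest neighbors of every point
    (the point itself included), computed one block of rows at a time."""
    data = jnp.asarray(data)
    n = data.shape[0]
    out = []
    for start in range(0, n, chunk_size):
        block = data[start:start + chunk_size]
        # A shorter final block triggers one extra compilation for its shape.
        out.append(np.asarray(_block_knn(block, data, k)))
    return np.concatenate(out, axis=0)
```

For the 10,000 x 100 blobs in the example above, a full float32 distance matrix would need about 400 MB (10,000² x 4 bytes), while each 1,024-row block needs roughly 41 MB. Smaller blocks lower peak memory at the cost of more kernel launches.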
