Welcome to DiRe-JAX’s documentation!


DiRe-JAX is a new dimensionality reduction package written in JAX and designed for high performance and efficient computation.

Quick Start

Installation

To install only the core package, which provides the main DiRe class:

pip install dire-jax

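If you also need the benchmarking utilities, install the utils extra (the quotes keep shells such as zsh from interpreting the square brackets):

pip install "dire-jax[utils]"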

Example Usage

from dire_jax import DiRe
from sklearn.datasets import make_blobs

# Synthetic data: 100,000 points in 1,000 dimensions, drawn from 12 Gaussian blobs.
n_samples  = 100_000
n_features = 1_000
n_centers  = 12
features_blobs, labels_blobs = make_blobs(n_samples=n_samples,
                                          n_features=n_features,
                                          centers=n_centers,
                                          random_state=42)

reducer_blobs = DiRe(dimension=2,
                     n_neighbors=16,
                     init_embedding_type='pca',
                     max_iter_layout=32,
                     min_dist=1e-4,
                     spread=1.0,
                     cutoff=4.0,
                     n_sample_dirs=8,
                     sample_size=16,
                     neg_ratio=32,
                     verbose=False,)

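# Fit the reducer and compute the 2D embedding; the return value is discarded here.
_ = reducer_blobs.fit_transform(features_blobs)
# Plot the embedding, coloring each point by its blob label.
reducer_blobs.visualize(labels=labels_blobs, point_size=4)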
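The example above passes init_embedding_type='pca', which suggests that the layout starts from a PCA projection of the input. As a rough illustration of what such a starting point looks like, here is a minimal sketch in plain NumPy. It is not DiRe-JAX's own code: the helper pca_project and the toy data are hypothetical names introduced only for this illustration, and the library's actual initialization may differ.

```python
import numpy as np

def pca_project(features, dimension=2):
    """Project the rows of `features` onto their top `dimension` principal axes.

    Hypothetical helper for illustration only; not part of dire_jax.
    """
    # Center each feature so the principal axes pass through the data mean.
    centered = features - features.mean(axis=0)
    # Rows of vt are the principal directions, ordered by decreasing variance.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:dimension].T

# Small toy input so the sketch runs quickly on any machine.
rng = np.random.default_rng(0)
toy_features = rng.normal(size=(200, 10))
toy_init = pca_project(toy_features, dimension=2)
print(toy_init.shape)
```

The result has one row per input point and one column per target dimension, which is the shape a 2D layout would then be refined from.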
