Usage Guide

Basic Usage

DiRe-JAX offers fast dimensionality reduction with JAX-based computation.

Quick Start

from dire_jax import DiRe
import numpy as np

# Create some sample data
data = np.random.random((1000, 50))  # 1000 samples in 50 dimensions

# Initialize DiRe with desired parameters
reducer = DiRe(
    n_components=2,         # Target dimension
    n_neighbors=15,         # Number of neighbors to consider
    init='pca',             # Initialization method
    max_iter_layout=128,    # Maximum number of layout iterations
    verbose=True            # Show progress
)

# Fit and transform the data
embedding = reducer.fit_transform(data)

# Visualize the results
reducer.visualize()
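For intuition about the init='pca' option used above: a PCA initialization seeds the layout with a linear projection of the centered data onto its top principal components. A minimal NumPy sketch of that computation (illustrative only; DiRe's actual initialization may scale or perturb the result):

```python
import numpy as np

def pca_init(X, n_components=2):
    """Project X onto its top principal components (a common 'pca' init).

    Sketch only; DiRe's actual initialization may differ in scaling."""
    Xc = X - X.mean(axis=0)                      # center the data
    U, S, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Xc @ Vt[:n_components].T              # (n_samples, n_components)

X = np.random.default_rng(2).random((1000, 100))
init_coords = pca_init(X, 2)                     # starting layout, shape (1000, 2)
```

Starting from a PCA projection rather than random coordinates usually preserves coarse global structure and lets the layout iterations converge faster.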

Performance Characteristics

DiRe-JAX is built for performance across dataset sizes and hardware:

  • Small to medium datasets (<50K points)

  • Large datasets (>65K points), handled via automatic memory management

  • Fully vectorized computation with JIT compilation

  • Excellent CPU performance

  • GPU acceleration when JAX is installed with CUDA support

  • TPU support for cloud-based computation

  • Mixed precision arithmetic (MPA) for enhanced performance on modern hardware

  • Memory-efficient chunking to handle large datasets without memory issues

  • Optimized kernel caching to minimize recompilation and improve runtime
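To see what memory-efficient chunking buys, consider pairwise distances: materializing the computation for all rows at once is the memory bottleneck, so processing a block of rows at a time caps peak usage. A NumPy sketch of the idea (DiRe's internals are JAX-based and differ in detail; this only illustrates the chunking pattern):

```python
import numpy as np

def chunked_sq_dists(X, Y, chunk=256):
    """Squared Euclidean distances from rows of X to rows of Y, computed
    one chunk of rows at a time to bound peak memory.

    Illustrative sketch only -- not DiRe's implementation."""
    out = np.empty((X.shape[0], Y.shape[0]), dtype=X.dtype)
    y_sq = (Y * Y).sum(axis=1)                    # precompute ||y||^2 once
    for start in range(0, X.shape[0], chunk):
        xc = X[start:start + chunk]               # only `chunk` rows live at once
        x_sq = (xc * xc).sum(axis=1)[:, None]
        # ||x - y||^2 = ||x||^2 + ||y||^2 - 2 <x, y>
        out[start:start + chunk] = x_sq + y_sq - 2.0 * xc @ Y.T
    return out

X = np.random.default_rng(0).random((1000, 50)).astype(np.float32)
d = chunked_sq_dists(X, X, chunk=128)             # shape (1000, 1000)
```

Smaller chunks lower peak memory at the cost of more loop iterations; the same trade-off is what a batch_size parameter controls.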

Advanced Configuration

DiRe offers several parameters that can be tuned to optimize the dimensionality reduction process:

  • n_components: Target dimension for the embedding (typically 2 or 3)

  • n_neighbors: Number of neighbors to consider when constructing the graph

  • init: Method to initialize the embedding ('pca', 'random', 'spectral')

  • max_iter_layout: Maximum number of iterations for the layout algorithm

  • min_dist: Minimum distance between points in the embedding

  • spread: Controls how spread out the embedding is

  • cutoff: Maximum distance for neighbor connections

  • n_sample_dirs: Number of sample directions for the layout algorithm

  • sample_size: Sample size for the layout algorithm

  • neg_ratio: Ratio of negative to positive samples

  • batch_size: Number of samples to process at once (None for automatic sizing)

  • mpa: Enable Mixed Precision Arithmetic for improved performance (default: True)

  • memm: Memory manager dictionary for different hardware architectures
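For intuition about n_neighbors: the embedding is driven by a k-nearest-neighbor graph over the input, and this parameter sets k. A brute-force NumPy sketch of that graph construction (DiRe's own neighbor search is a faster, separate implementation; this only shows the concept):

```python
import numpy as np

def knn_indices(X, k):
    """Indices of the k nearest neighbors of each row (brute force).

    Sketch for intuition only -- DiRe uses its own neighbor search."""
    d = ((X[:, None, :] - X[None, :, :]) ** 2).sum(-1)   # all pairwise sq. dists
    np.fill_diagonal(d, np.inf)                          # a point is not its own neighbor
    # argpartition finds the k smallest entries per row without a full sort
    return np.argpartition(d, k, axis=1)[:, :k]

X = np.random.default_rng(1).random((500, 10))
nbrs = knn_indices(X, k=15)                              # shape (500, 15)
```

Larger k connects each point to more of the dataset, emphasizing global structure; smaller k keeps the graph local and emphasizes fine-grained clusters.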

Example with Custom Parameters

from dire_jax import DiRe
from sklearn.datasets import make_blobs

# Create dataset with clusters
features, labels = make_blobs(
    n_samples=10000,
    n_features=100,
    centers=5,
    random_state=42
)

# Initialize with custom parameters
reducer = DiRe(
    n_components=2,
    n_neighbors=30,          # More neighbors for global structure
    init='spectral',         # Spectral initialization
    max_iter_layout=256,     # More iterations for convergence
    min_dist=0.01,           # Tighter packing
    spread=2.0,              # More spread out embedding
    verbose=True
)

# Fit and transform
embedding = reducer.fit_transform(features)

# Visualize with labels
reducer.visualize(labels=labels, point_size=3)

Large Dataset Example

For very large datasets, DiRe-JAX automatically switches to memory-efficient mode:

from dire_jax import DiRe
import numpy as np

# Create a large dataset
large_data = np.random.random((100000, 200))  # 100K samples, 200 dimensions

# DiRe will automatically use large dataset mode
reducer = DiRe(
    n_components=2,
    n_neighbors=16,
    batch_size=4096,        # Custom batch size for memory control
    mpa=True,               # Enable mixed precision arithmetic
    max_iter_layout=64,     # Fewer iterations for faster processing
    verbose=True            # Monitor progress
)

# Memory-efficient processing with automatic chunking
embedding = reducer.fit_transform(large_data)

# Visualize the result
reducer.visualize(point_size=1)
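The memory side of mpa=True is easy to see outside of DiRe: halving the floating-point width halves an array's footprint. A standalone NumPy illustration (DiRe manages precision internally through JAX; this is not its mechanism, just the underlying arithmetic):

```python
import numpy as np

data64 = np.random.default_rng(3).random((10_000, 200))   # float64 by default
data32 = data64.astype(np.float32)                        # half the bytes
data16 = data64.astype(np.float16)                        # a quarter of the bytes

for a in (data64, data32, data16):
    print(f"{a.dtype}: {a.nbytes / 1e6:.0f} MB")
# float64: 16 MB, float32: 8 MB, float16: 4 MB
```

Narrower types also tend to run faster on accelerators with dedicated low-precision units, which is why mixed precision helps on modern GPUs and TPUs; the trade-off is reduced numeric range and accuracy in the low-precision parts of the computation.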

Benchmarking

If you’ve installed DiRe-JAX with the [utils] extra, you can use the benchmarking utilities:

from dire_jax import DiRe
from dire_jax.dire_utils import run_benchmark, viz_benchmark
from sklearn.datasets import make_blobs
from jax import random

# Create data
features, labels = make_blobs(
    n_samples=10000,
    n_features=100,
    centers=5,
    random_state=42
)

# Initialize reducer
reducer = DiRe(n_components=2, n_neighbors=15)

# Run the benchmark
benchmark_results = run_benchmark(
    reducer,
    features,
    labels=labels,
    dimension=1,  # homology dimension for persistent homology
    subsample_threshold=0.1,  # subsample for speed
    rng_key=random.PRNGKey(42),
    num_trials=1,  # number of benchmark trials
    only_stats=True,
)

# Print the results
print(benchmark_results)

# Or visualize the benchmark
viz_benchmark(
    reducer,
    features,
    labels=labels,
    dimension=1,  # homology dimension for persistent homology
    subsample_threshold=0.1,  # subsample for speed
    rng_key=random.PRNGKey(42),
    point_size=2
)

Working with Different Data Types

DiRe-JAX works with various data formats:

import numpy as np
import pandas as pd
from dire_jax import DiRe

# NumPy arrays
data_numpy = np.random.random((1000, 50))

# Pandas DataFrames
data_df = pd.DataFrame(data_numpy)

# Both work; pass the DataFrame's underlying NumPy array
reducer = DiRe(n_components=2)
embedding_numpy = reducer.fit_transform(data_numpy)
embedding_df = reducer.fit_transform(data_df.values)  # .values extracts the array
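A small convenience pattern for normalizing inputs before calling fit_transform. The helper below is a hypothetical utility, not part of the DiRe API; it coerces common array-likes to a contiguous float array:

```python
import numpy as np

def as_float_array(data, dtype=np.float32):
    """Coerce array-like input (NumPy array, DataFrame, nested list) to a
    contiguous float array.

    A hypothetical convenience helper -- not part of dire_jax."""
    # DataFrames expose .values; plain arrays and lists pass through unchanged
    arr = np.asarray(getattr(data, "values", data), dtype=dtype)
    return np.ascontiguousarray(arr)

clean = as_float_array([[1, 2], [3, 4]])   # nested lists work too
```

Converting once up front also pins the dtype, so downstream JAX computation does not silently promote to float64.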