Usage Guide
Basic Usage
DiRe-JAX provides a high-performance dimensionality reduction tool based on JAX. Here’s a quick example of how to use it:
from dire_jax import DiRe
import numpy as np
# Create some sample data
data = np.random.random((1000, 50)) # 1000 samples in 50 dimensions
# Initialize DiRe with desired parameters
reducer = DiRe(
dimension=2, # Target dimension
n_neighbors=15, # Number of neighbors to consider
init_embedding_type='pca', # Initialization method
max_iter_layout=128, # Maximum number of layout iterations
verbose=True # Show progress
)
# Fit and transform the data
embedding = reducer.fit_transform(data)
# Visualize the results
reducer.visualize()
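With init_embedding_type='pca', the layout starts from a linear projection of the data instead of random positions. The sketch below shows what a PCA initialization computes, assuming a plain centered-SVD PCA. The helper `pca_init` is hypothetical and is not DiRe-JAX's internal initializer.

```python
import numpy as np

def pca_init(data: np.ndarray, dimension: int = 2) -> np.ndarray:
    """Project data onto its top `dimension` principal components.

    Hypothetical helper illustrating a PCA-style initial embedding;
    not the DiRe-JAX implementation.
    """
    centered = data - data.mean(axis=0)          # PCA operates on centered data
    # Rows of vt are principal directions, ordered by decreasing singular value
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:dimension].T           # shape: (n_samples, dimension)

rng = np.random.default_rng(0)
data = rng.random((1000, 50))                    # 1000 samples in 50 dimensions
init = pca_init(data, dimension=2)
print(init.shape)  # (1000, 2)
```

The first output column carries the most variance and the columns are mutually uncorrelated, which gives the layout algorithm a deterministic, globally coherent starting point.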
Advanced Configuration
DiRe offers several parameters that can be tuned to optimize the dimensionality reduction process:
dimension: Target dimension for the embedding (typically 2 or 3)
n_neighbors: Number of neighbors to consider when constructing the graph
init_embedding_type: Method to initialize the embedding (‘pca’, ‘random’)
max_iter_layout: Maximum number of iterations for the layout algorithm
min_dist: Minimum distance between points in the embedding
spread: Controls how spread out the embedding is
cutoff: Maximum distance for neighbor connections
n_sample_dirs: Number of sample directions for the layout algorithm
sample_size: Sample size for the layout algorithm
neg_ratio: Ratio of negative to positive samples
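The n_neighbors parameter sets how many nearest neighbors each point is connected to when the graph is built. As a rough illustration of what that graph contains, here is a brute-force k-nearest-neighbor search in NumPy. The helper `knn_graph` is hypothetical; DiRe-JAX's own neighbor search may use a different (for example, approximate or batched) method.

```python
import numpy as np

def knn_graph(data: np.ndarray, n_neighbors: int = 15):
    """Return (indices, distances) of each point's n_neighbors nearest neighbors.

    Hypothetical brute-force helper for illustration only: it builds the full
    n x n distance matrix, so it is O(n^2) in time and memory.
    """
    sq_norms = np.sum(data ** 2, axis=1)
    # Squared Euclidean distances via ||x||^2 - 2 x.y + ||y||^2
    sq_dists = sq_norms[:, None] - 2.0 * data @ data.T + sq_norms[None, :]
    np.fill_diagonal(sq_dists, np.inf)           # a point is not its own neighbor
    # argpartition moves the n_neighbors smallest entries of each row to the front
    idx = np.argpartition(sq_dists, n_neighbors, axis=1)[:, :n_neighbors]
    rows = np.arange(data.shape[0])[:, None]
    # Order the selected neighbors nearest-first
    order = np.argsort(sq_dists[rows, idx], axis=1)
    idx = idx[rows, order]
    # Clip tiny negative values caused by floating-point cancellation
    dists = np.sqrt(np.maximum(sq_dists[rows, idx], 0.0))
    return idx, dists

rng = np.random.default_rng(0)
data = rng.random((200, 10))
neighbor_idx, neighbor_dist = knn_graph(data, n_neighbors=15)
print(neighbor_idx.shape)  # (200, 15)
```

Larger n_neighbors values connect each point to a wider neighborhood, which tends to emphasize global structure at the cost of fine local detail.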
Benchmarking
If you installed DiRe-JAX with the [utils] extra (pip install "dire-jax[utils]"), you can use the benchmarking utilities:
from dire_jax import DiRe
from dire_jax.dire_utils import run_benchmark, viz_benchmark
from sklearn.datasets import make_blobs
from jax import random
# Create data
features, labels = make_blobs(n_samples=10000, n_features=100, centers=5, random_state=42)
# Initialize reducer
reducer = DiRe(dimension=2, n_neighbors=15)
# Then either run the benchmark ...
benchmark_results = run_benchmark(reducer,
features,
labels=labels,
dimension=1, # homology dimension for persistent homology
subsample_threshold=0.1, # subsample for speed
rng_key=random.PRNGKey(42),
num_trials=1, # number of benchmark trials
only_stats=True)
# and print the results ...
print(benchmark_results)
# ... or visualize the benchmark ...
viz_benchmark(reducer,
features,
labels=labels,
dimension=1, # homology dimension for persistent homology
subsample_threshold=0.1, # subsample for speed
rng_key=random.PRNGKey(42),
point_size=2)
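The subsample_threshold argument speeds up the benchmark by evaluating metrics on a subset of the points. The sketch below assumes it is the fraction of points kept (0.1 means 10%) and shows how such a subset can be drawn reproducibly from a JAX PRNG key. The helper `subsample` is hypothetical and is not part of dire_jax.dire_utils.

```python
import jax
import numpy as np

def subsample(key, features, labels, fraction=0.1):
    """Draw a random subset of rows without replacement.

    Hypothetical helper; assumes `fraction` plays the role of subsample_threshold.
    """
    n = features.shape[0]
    k = max(1, int(n * fraction))                # number of points to keep
    idx = jax.random.choice(key, n, shape=(k,), replace=False)
    idx = np.asarray(idx)                        # host array for indexing NumPy data
    return features[idx], labels[idx]

features = np.random.default_rng(0).random((10000, 100))
labels = np.arange(10000) % 5
sub_x, sub_y = subsample(jax.random.PRNGKey(42), features, labels, fraction=0.1)
print(sub_x.shape)  # (1000, 100)
```

JAX random functions are deterministic given the key, so reusing random.PRNGKey(42) reproduces the same subset. Under the assumption above, passing the same rng_key to run_benchmark and viz_benchmark would keep their subsamples comparable.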