API Reference
DiRe-JAX is a high-performance dimensionality reduction library built on JAX, designed for efficient processing of large-scale, high-dimensional datasets.
DiRe JAX Overview
A JAX-based dimensionality reducer.
- class DiRe(dimension=2, n_neighbors=16, init_embedding_type='random', sim_kernel=None, pca_kernel=None, max_iter_layout=128, min_dist=0.01, spread=1.0, cutoff=42.0, n_sample_dirs=8, sample_size=16, neg_ratio=8, my_logger=None, verbose=True, memm=None, mpa=True)[source]
Bases:
object
Dimension Reduction (DiRe) is a class designed to reduce the dimensionality of high-dimensional data using various embedding techniques and optimization algorithms. It supports embedding initialization methods such as random and spectral embeddings and utilizes k-nearest neighbors for manifold learning.
- Parameters:
dimension (int) – Embedding dimension, default 2.
n_neighbors (int) – Number of nearest neighbors to consider for each point, default 16.
init_embedding_type (str) – Method to initialize the embedding; choices are:
‘random’ for random projection based on the Johnson-Lindenstrauss lemma;
‘spectral’ for spectral embedding (with sim_kernel as similarity kernel);
‘pca’ for PCA embedding (classical PCA, or Kernel PCA if pca_kernel is given).
By default, ‘random’.
sim_kernel (callable) – A similarity kernel function that transforms a distance metric to a similarity score. The function should have the form lambda distance: float -> similarity: float; default None.
pca_kernel (callable) – Kernel function to be used if init_embedding_type is ‘pca’; default None.
max_iter_layout (int) – Maximum number of iterations to run the layout optimization, default 128.
min_dist (float) – Minimum distance scale for distribution kernel, default 0.01.
spread (float) – Spread of the distribution kernel, default 1.0.
cutoff (float) – Cutoff for clipping forces during layout optimization, default 42.0.
n_sample_dirs (int) – Number of directions to sample in random sampling, default 8.
sample_size (int) – Number of samples per direction in random sampling, default 16.
neg_ratio (int) – Ratio of negative to positive samples in random sampling, default 8.
my_logger (logger.Logger or None) – Custom logger for logging events; if None, a default logger is created, default None.
verbose (bool) – Flag to enable verbose output, default True.
memm (dictionary or None) – Memory manager: a dictionary with the batch / memory tile size for different hardware architectures, keyed by ‘tpu’, ‘gpu’ and ‘other’; default None.
mpa (bool) – Mixed Precision Arithmetic flag (True = use MPA, False = always use float64), default True.
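As an illustration of the callable form described for sim_kernel, a Gaussian kernel is one common choice. This is only a sketch: the name gaussian_sim_kernel and the bandwidth sigma are assumptions and not part of DiRe-JAX, and since this reference does not say whether the kernel is called on scalars or on whole arrays, the function accepts both.

```python
import numpy as np

def gaussian_sim_kernel(distance, sigma=1.0):
    """Map a distance (scalar or array) to a similarity in (0, 1].

    Hypothetical example kernel; sigma is an assumed bandwidth parameter.
    """
    distance = np.asarray(distance, dtype=float)
    return np.exp(-(distance ** 2) / (2.0 * sigma ** 2))

# Zero distance gives similarity 1; similarity decays as distance grows.
print(gaussian_sim_kernel(0.0))
print(gaussian_sim_kernel(np.array([0.5, 1.0, 2.0])))
```

A callable of this kind would be passed as sim_kernel together with init_embedding_type=‘spectral’.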
- dimension
Target dimensionality of the output space.
- Type:
int
- n_neighbors
Number of neighbors to consider in the k-nearest neighbors graph.
- Type:
int
- init_embedding_type
Chosen method for initial embedding.
- Type:
str
- sim_kernel
Similarity kernel function to be used if ‘init_embedding_type’ is ‘spectral’, by default None.
- Type:
callable
- pca_kernel
Kernel function to be used if ‘init_embedding_type’ is ‘pca’, by default None.
- Type:
callable
- max_iter_layout
Maximum iterations for optimizing the layout.
- Type:
int
- min_dist
Minimum distance for repulsion used in the distribution kernel.
- Type:
float
- spread
Spread between the data points used in the distribution kernel.
- Type:
float
- cutoff
Maximum cutoff for forces during optimization.
- Type:
float
- n_sample_dirs
Number of random directions sampled.
- Type:
int
- sample_size
Number of samples per random direction, unless chosen automatically with ‘auto’.
- Type:
int or ‘auto’
- neg_ratio
Ratio of negative to positive samples in the sampling process.
- Type:
int
- logger
Logger used for logging informational and warning messages.
- Type:
logger.Logger or None
- verbose
Logger output flag (True = output logger messages, False = flush to null)
- Type:
bool
- memm
Memory manager: a dictionary with the batch / memory tile size for different hardware architectures. Accepts ‘tpu’, ‘gpu’ and ‘other’ as keys. Values must be positive integers.
- Type:
dictionary or None
- mpa
Mixed Precision Arithmetic flag (True = use MPA, False = always use float64)
- Type:
bool
- fit_transform(data)[source]
A convenience method that fits the model and then transforms the data. The separate fit and transform methods can only be used one after another because dimension reduction is applied to the dataset as a whole.
- visualize(labels=None, point_size=2)[source]
Visualizes the transformed data, optionally using labels to color the points.
- __init__(dimension=2, n_neighbors=16, init_embedding_type='random', sim_kernel=None, pca_kernel=None, max_iter_layout=128, min_dist=0.01, spread=1.0, cutoff=42.0, n_sample_dirs=8, sample_size=16, neg_ratio=8, my_logger=None, verbose=True, memm=None, mpa=True)[source]
Class constructor
- do_layout(large_dataset_mode=None, force_cpu=False)[source]
Optimize the layout using force-directed placement with JAX kernels.
This method takes the initial embedding and iteratively refines it using attractive and repulsive forces to create a meaningful low-dimensional representation of the high-dimensional data. The algorithm applies:
Attraction forces between points that are neighbors in the high-dimensional space
Repulsion forces between randomly sampled points in the low-dimensional space
Gradual cooling (decreasing force impact) as iterations progress
The final layout is normalized to have zero mean and unit standard deviation.
- Parameters:
large_dataset_mode (bool or None, optional) – If True, use memory-efficient techniques for large datasets. If None, automatically determine based on dataset size.
force_cpu (bool, optional) – If True, force computations on CPU instead of GPU, which can be helpful for large datasets that exceed GPU memory.
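The loop described above (attraction along kNN edges, repulsion from random samples, cooling, final normalization) can be sketched in plain NumPy. This is a toy illustration under stated assumptions, not the library's JAX kernels: the function name, the force formulas, the linear cooling schedule, and the per-coordinate normalization are all assumed for the example.

```python
import numpy as np

def force_layout_sketch(init, neighbors, n_iter=128, n_neg=8, cutoff=42.0, seed=0):
    """Toy force-directed refinement of an initial embedding.

    init      : (n, d) initial low-dimensional positions
    neighbors : (n, k) indices of each point's high-dimensional neighbors
    The force formulas below are illustrative assumptions, not DiRe's kernels.
    """
    rng = np.random.default_rng(seed)
    pos = np.array(init, dtype=float)
    n = pos.shape[0]
    for it in range(n_iter):
        alpha = 1.0 - it / n_iter  # linear cooling schedule (assumed)
        # Attraction: pull each point toward the mean offset of its neighbors.
        att = (pos[neighbors] - pos[:, None, :]).mean(axis=1)
        # Repulsion: push away from randomly sampled points, decaying with distance.
        samples = rng.integers(0, n, size=(n, n_neg))
        diff = pos[:, None, :] - pos[samples]
        dist2 = (diff ** 2).sum(axis=-1, keepdims=True)
        rep = (diff / (1.0 + dist2)).mean(axis=1)
        # Clip the net force, then take a step scaled by the cooling factor.
        force = np.clip(att + rep, -cutoff, cutoff)
        pos += alpha * force
    # Normalize to zero mean and unit standard deviation (per coordinate, assumed).
    pos -= pos.mean(axis=0)
    pos /= pos.std(axis=0)
    return pos
```

Clipping with cutoff mirrors the role of the cutoff parameter documented for the class; the step size shrinking with alpha is what lets the layout settle.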
- do_pca_embedding()[source]
Initialize embedding using Principal Component Analysis (PCA).
This method creates an initial embedding of the data using PCA, which finds a linear projection of the high-dimensional data into a lower-dimensional space that maximizes the variance. If a kernel is specified, Kernel PCA is used instead, which can capture nonlinear relationships.
Sets the init_embedding attribute with the PCA projection of the data.
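For orientation, classical PCA can be written in a few lines of NumPy via the singular value decomposition. This is a minimal sketch of the technique, not the library's implementation; the function name is hypothetical and the kernel variant is not covered.

```python
import numpy as np

def pca_embedding_sketch(data, dimension=2):
    """Project data onto its top `dimension` principal components via SVD."""
    centered = data - data.mean(axis=0)
    # Rows of vt are principal directions, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return centered @ vt[:dimension].T
```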
- do_rand_sampling(key, arr, n_samples, n_dirs, neg_ratio)[source]
Sample points for force calculation using random projections.
This method implements an efficient sampling strategy to identify points for applying attractive and repulsive forces during layout optimization. It uses random projections to quickly identify nearby points in different directions, and also adds random negative samples for repulsion.
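- Parameters:
key – JAX random key used for sampling
arr – Array of points to sample from
n_samples – Number of closest points to take in each direction
n_dirs – Number of random directions
neg_ratio – Ratio of negative to positive samples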
- Returns:
Array of sampled indices for force calculations
- Return type:
jax.numpy.ndarray
Notes
The sampling strategy works as follows:
1. Generate n_dirs random unit vectors
2. Project the points onto each vector
3. For each point, take the n_samples closest points in each direction
4. Add random negative samples for repulsion
5. Combine all sampled indices
This approach is more efficient than a full nearest neighbor search while still capturing the important local relationships.
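The five steps in the Notes can be sketched in NumPy as follows. Several details are assumptions made for the example: the function name, the integer seed in place of a JAX key, taking "closest in a direction" to mean adjacent in the sorted projection, and drawing neg_ratio * n_samples negatives per point.

```python
import numpy as np

def random_sampling_sketch(arr, n_samples, n_dirs, neg_ratio, seed=0):
    """Return an (n, n_dirs * n_samples + neg_ratio * n_samples) index array."""
    rng = np.random.default_rng(seed)
    n, dim = arr.shape
    # Step 1: n_dirs random unit vectors.
    dirs = rng.normal(size=(n_dirs, dim))
    dirs /= np.linalg.norm(dirs, axis=1, keepdims=True)
    # Step 2: project every point onto every direction -> (n, n_dirs).
    proj = arr @ dirs.T
    # Rank offsets selecting n_samples neighbors around each point.
    half = n_samples // 2
    offsets = np.concatenate([np.arange(-half, 0), np.arange(1, n_samples - half + 1)])
    picked = []
    for d in range(n_dirs):
        order = np.argsort(proj[:, d])  # point ids sorted along direction d
        rank = np.empty(n, dtype=np.int64)
        rank[order] = np.arange(n)      # rank of each point in that order
        # Step 3: points adjacent in sorted order are close in this projection.
        # Ranks are clipped at the ends, so edge points may repeat an index.
        window = np.clip(rank[:, None] + offsets[None, :], 0, n - 1)
        picked.append(order[window])
    # Step 4: uniform random negatives (the count is an assumption).
    negatives = rng.integers(0, n, size=(n, neg_ratio * n_samples))
    # Step 5: combine everything into one index array per point.
    return np.concatenate(picked + [negatives], axis=1)
```

Sorting each projection costs O(n log n) per direction, which is where the saving over an exhaustive neighbor search comes from.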
- do_random_embedding()[source]
Initialize embedding using Random Projection.
This method creates an initial embedding of the data using random projection, which is a simple and computationally efficient technique for dimensionality reduction. It projects the data onto a randomly generated basis, providing a good starting point for further optimization.
Random projection is supported by the Johnson-Lindenstrauss lemma, which guarantees that the distances between points are approximately preserved under certain conditions.
Sets the init_embedding attribute with the random projection of the data.
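A Gaussian random projection is one standard way to realize this idea. The sketch below is an assumption about the general technique rather than a copy of the library's code; the function name and the seed argument are hypothetical.

```python
import numpy as np

def random_embedding_sketch(data, dimension=2, seed=0):
    """Project data onto `dimension` random Gaussian directions."""
    rng = np.random.default_rng(seed)
    n_features = data.shape[1]
    # The 1/sqrt(dimension) scale keeps squared distances unbiased in
    # expectation, as in the Johnson-Lindenstrauss setting.
    proj = rng.normal(size=(n_features, dimension)) / np.sqrt(dimension)
    return data @ proj
```

With dimension as small as 2 or 3 the distance guarantee is weak, which is consistent with using the projection only as a starting point for the layout optimization.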
- do_spectral_embedding()[source]
Initialize embedding using Spectral Embedding.
This method creates an initial embedding of the data using spectral embedding, which is based on the eigenvectors of the graph Laplacian. It relies on the kNN graph structure to find a lower-dimensional representation that preserves local relationships.
If a similarity kernel is specified, it is applied to transform the distances in the adjacency matrix before computing the Laplacian.
Sets the init_embedding attribute with the spectral embedding of the data.
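The Laplacian eigenvector construction can be illustrated on a small dense graph. The sketch assumes a symmetric, connected adjacency matrix and the symmetric normalized Laplacian; the function name is hypothetical, and a kNN adjacency would typically need to be symmetrized first.

```python
import numpy as np

def spectral_embedding_sketch(adjacency, dimension=2):
    """Embed graph nodes using eigenvectors of the normalized graph Laplacian.

    adjacency : (n, n) symmetric, non-negative weight matrix of a connected graph
    """
    degree = adjacency.sum(axis=1)
    inv_sqrt = 1.0 / np.sqrt(degree)
    # Symmetric normalized Laplacian: L = I - D^{-1/2} A D^{-1/2}.
    lap = np.eye(len(adjacency)) - inv_sqrt[:, None] * adjacency * inv_sqrt[None, :]
    # eigh returns eigenvalues in ascending order; skip the trivial first one.
    _, vecs = np.linalg.eigh(lap)
    return vecs[:, 1:dimension + 1]
```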
- find_ab_params(min_dist=0.01, spread=1.0)[source]
Rational function approximation to the probabilistic t-kernel
- fit(data)[source]
Fit the model to data: create the kNN graph and fit the probability kernel to force layout parameters.
- Parameters:
data ((numpy.ndarray)) – High-dimensional data to fit the model. Shape (n_samples, n_features).
- Returns:
self, the DiRe instance fitted to data.
- Return type:
DiRe
- fit_transform(data)[source]
Fit the model to data and transform it into a low-dimensional layout.
This is a convenience method that combines the fitting and transformation steps. It first builds the kNN graph and then creates the optimized layout in a single operation.
- Parameters:
data (numpy.ndarray) – High-dimensional data to fit and transform. Shape (n_samples, n_features)
- Returns:
The lower-dimensional embedding of the data. Shape (n_samples, dimension)
- Return type:
Notes
This method is more memory-efficient than calling fit() and transform() separately, as it avoids storing intermediate results.
- make_knn_adjacency(batch_size=None)[source]
Internal routine building the adjacency matrix for the kNN graph.
This method computes the k-nearest neighbors for each point in the dataset and constructs a sparse adjacency matrix representing the kNN graph. It attempts to use GPU acceleration if available, with a fallback to CPU.
For large datasets, it uses batching to limit memory usage.
- Parameters:
batch_size (int or None, optional) – Number of samples to process at once. If None, a suitable value will be automatically determined based on dataset size.
The method sets the following instance attributes:
distances
indices
nearest_neighbor_distances
row_idx, col_idx (indices for constructing the sparse adjacency matrix)
adjacency
- transform()[source]
Transform the fitted data into a lower-dimensional layout.
This method applies the selected embedding initialization technique to the data that has already been fitted (creating the kNN graph), and then optimizes the layout using force-directed placement.
The transformation process involves:
Creating an initial embedding using the specified method (random projection, PCA, or spectral embedding)
Optimizing the layout with attractive and repulsive forces
- Returns:
The lower-dimensional data embedding with shape (n_samples, dimension). Points are arranged to preserve the local structure of the original data.
- Return type:
- Raises:
ValueError – If an unsupported embedding initialization method is specified.
- visualize(labels=None, point_size=2, title=None, colormap=None, width=800, height=600, opacity=0.7)[source]
Generate an interactive visualization of the data in the transformed space.
This method creates a scatter plot visualization of the embedded data, supporting both 2D and 3D visualizations depending on the specified dimension. Points can be colored by provided labels for clearer visualization of clusters or categories.
- Parameters:
labels (numpy.ndarray or None, optional) – Labels for each data point to color the points in the visualization. If None, all points will have the same color. Default is None.
point_size (int or float, optional) – Size of points in the scatter plot. Default is 2.
title (str or None, optional) – Title for the visualization. If None, a default title will be used. Default is None.
colormap (str or None, optional) – Name of the colormap to use for labels (e.g., ‘viridis’, ‘plasma’). If None, the default Plotly colormap will be used. Default is None.
width (int, optional) – Width of the figure in pixels. Default is 800.
height (int, optional) – Height of the figure in pixels. Default is 600.
opacity (float, optional) – Opacity of the points (0.0 to 1.0). Default is 0.7.
- Returns:
A Plotly figure object if the visualization is successful; None if no layout is available or dimension > 3.
- Return type:
plotly.graph_objs._figure.Figure or None
Notes
For 3D visualizations, you can rotate, zoom, and pan the plot interactively. For both 2D and 3D, hover over points to see their coordinates and labels.
- dimension
Embedding dimension
- n_neighbors
Number of neighbors for kNN computations
- init_embedding_type
Type of the initial embedding (PCA, random, spectral)
- sim_kernel
Similarity kernel
- pca_kernel
PCA kernel
- max_iter_layout
Max iterations for the force layout
- min_dist
Min distance between points in layout
- spread
Layout spread
- cutoff
Cutoff for layout displacement
- n_sample_dirs
Number of sampling directions for layout
- sample_size
Sample size for attraction
- neg_ratio
Ratio for repulsion sample size
- logger
System logger
Core Components
DiRe Class
Provides the main class for dimensionality reduction.
The DiRe (Dimensionality Reduction) class implements a modern approach to dimensionality reduction, leveraging JAX for efficient computation. It uses force-directed layout techniques combined with k-nearest neighbor graph construction to generate meaningful low-dimensional embeddings of high-dimensional data.
- compute_forces_kernel(positions, chunk_indices, neighbor_indices, sample_indices, alpha, a, b)[source]
JAX-optimized kernel for computing attractive and repulsive forces.
- Parameters:
positions (jax.numpy.ndarray) – Current point positions
chunk_indices (jax.numpy.ndarray) – Current batch indices
neighbor_indices (jax.numpy.ndarray) – Indices of neighbors for attractive forces
sample_indices (jax.numpy.ndarray) – Indices of points for repulsive forces
alpha (float) – Cooling factor that scales force magnitude
a (float) – Attraction parameter
b (float) – Repulsion parameter
- Returns:
Net force vectors for each point
- Return type:
jax.numpy.ndarray
- distribution_kernel(dist, a, b)[source]
Probability kernel that maps distances to similarity scores.
This is a rational function approximation of a t-distribution.
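The exact formula is not given in this reference. A widely used rational kernel of this kind, known from UMAP, is 1 / (1 + a * dist^(2b)); the sketch below assumes that form purely for illustration, and the function name is hypothetical.

```python
import numpy as np

def distribution_kernel_sketch(dist, a, b):
    """Assumed UMAP-style rational kernel: 1 / (1 + a * dist**(2*b))."""
    dist = np.asarray(dist, dtype=float)
    return 1.0 / (1.0 + a * dist ** (2.0 * b))
```

With a = b = 1 this reduces to 1 / (1 + dist^2), the kernel of a Student t-distribution with one degree of freedom.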
- jax_coeff_att(dist, a, b)[source]
JAX-optimized attraction coefficient function.
- jax_coeff_rep(dist, a, b)[source]
JAX-optimized repulsion coefficient function.
- rand_directions(key, dim=2, num=100)[source]
Sample unit vectors in random directions.
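One common way to sample such vectors is to normalize isotropic Gaussian draws. The NumPy sketch below assumes that approach and uses an integer seed in place of a JAX key; whether the library samples uniformly on the sphere is not stated here.

```python
import numpy as np

def rand_directions_sketch(dim=2, num=100, seed=0):
    """Return `num` unit vectors in `dim` dimensions as a (num, dim) array."""
    rng = np.random.default_rng(seed)
    vecs = rng.normal(size=(num, dim))
    # Normalizing isotropic Gaussian samples gives directions uniform on the sphere.
    return vecs / np.linalg.norm(vecs, axis=1, keepdims=True)
```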
Utility Functions
Utilities for plotting, benchmarking, and miscellany
- display_layout(layout, labels, point_size=2)[source]
- Parameters:
layout (numpy.ndarray) – Layout to display, must have dimension 2 or 3.
labels (numpy.ndarray) – Labels to generate color and legend.
point_size (int) – Point size for plotting.
- Returns:
Plot of the layout if the latter has dimension 2 or 3 (using Plotly Express). For higher-dimensional data no plot is provided, and the function returns None.
- Return type:
- do_local_analysis(data, layout, n_neighbors)[source]
- Parameters:
data (numpy.ndarray) – High-dimensional data.
layout (numpy.ndarray) – Low-dimensional embedding.
n_neighbors (int) – Number of neighbors in the kNN graph.
- Returns:
None. Prints out the local metrics (embedding stress, neighborhood preservation score, etc.)
- Return type:
None
- visualize_persistence_diagram(diagram, dimension, title)[source]
Create a visualization of a persistence diagram.
- Parameters:
- Returns:
Interactive visualization of the persistence diagram
- Return type:
plotly.graph_objects.Figure
- do_persistence_analysis(data, layout, dimension, subsample_threshold, rng_key, n_steps=100)[source]
Perform a comprehensive persistence analysis by subsampling data, computing persistence diagrams, and calculating distances between Betti curves of high-dimensional and low-dimensional data. The distances computed include Dynamic Time Warping (DTW), Time Warp Edit Distance (TWED), and Earth Mover's Distance.
- Parameters:
data (numpy.ndarray) – High-dimensional data.
layout (numpy.ndarray) – Low-dimensional embedding.
dimension (int) – The dimension up to which persistence diagrams are computed.
subsample_threshold (float) – The threshold used for subsampling the data points.
rng_key (jax.random.PRNGKey) – Random key used for generating random numbers for subsampling, ensuring reproducibility.
n_steps (int, optional) – The number of steps or points in the filtration range for computing Betti curves, default 100.
- Returns:
This function primarily visualizes results and prints metric values.
- Return type:
None
- do_context_analysis(data, layout, labels, subsample_threshold, n_neighbors, rng_key, **kwargs)[source]
- Parameters:
data (numpy.ndarray) – High-dimensional data.
layout (numpy.ndarray) – Low-dimensional embedding.
labels (numpy.ndarray) – Data labels, required for context preservation analysis.
subsample_threshold (float) – Subsample threshold.
n_neighbors (int) – Number of nearest neighbors for the kNN graph of data.
rng_key (jax.random.PRNGKey) – Random key used for generating random numbers for subsampling, ensuring reproducibility.
kwargs – Keyword arguments for kNN and SVM score procedure, and similar.
- Returns:
None. This function prints out context preservation measures.
- Return type:
None
- block_timing()[source]
- Returns:
Elapsed runtime (in seconds) for a given block of code.
- Return type:
float
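How block_timing is invoked is not shown in this reference. A common pattern for timing a block of code is a context manager; the sketch below illustrates that pattern with a hypothetical helper named timed_block and should not be read as the library's interface.

```python
import time
from contextlib import contextmanager

@contextmanager
def timed_block():
    """Yield a callable that reports seconds elapsed since the block started."""
    start = time.perf_counter()
    yield lambda: time.perf_counter() - start

# Time an arbitrary block of code.
with timed_block() as elapsed:
    total = sum(range(100_000))
print(f"elapsed: {elapsed():.6f} s")
```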
- viz_benchmark(reducer, data, **kwargs)[source]
Run a benchmarking process for dimensionality reduction using provided reducer.
- Parameters:
reducer (object) – Dimensionality reduction model with a fit_transform method. It should also have an n_neighbors attribute for computing neighborhood scores.
data (numpy.ndarray) – High-dimensional data to be reduced.
kwargs – Keyword arguments for the benchmark’s metrics (such as labels if using labeled data, maximum dimension for persistent homology computation, threshold subsample_threshold for subsampling, etc.)
- Returns:
None. This function does not return anything. It performs the embedding, displays the layout, conducts persistence analysis, prints the embedding stress and neighborhood preservation score, and times the embedding process.
- Return type:
None
- do_metrics(reducer, data, **kwargs)[source]
Compute local and global metrics, and context preservation measures.
- Parameters:
reducer (object) – The dimensionality reduction model with a fit_transform method. It should also have an n_neighbors attribute for computing neighborhood scores.
data (numpy.ndarray) – The high-dimensional data to be reduced.
kwargs – Keyword arguments to be passed to compute_local_metrics, compute_global_metrics, and compute_context_measures.
- Returns:
A dictionary of local and global metrics, and context preservation measures.
- Return type:
dict
- run_benchmark(reducer, data, *, num_trials=100, only_stats=True, **kwargs)[source]
Benchmark a reducer on given data.
- Parameters:
reducer (object) – The dimensionality reduction model with a fit_transform method. It should also have an n_neighbors attribute for computing neighborhood scores.
data (numpy.ndarray) – The high-dimensional data to be reduced as a benchmark.
num_trials (int) – The number of runs to collect stats.
only_stats (bool) – If True, only the tuple (mean, std) for each metric is returned. If False, both stats and values for all runs are returned.
kwargs – Keyword arguments to be passed to do_metrics.
- Returns:
dict – If only_stats is True, a dictionary with stats of all metrics available.
dict, dict – If only_stats is False, a dictionary with stats and a dictionary with the initial values of all metrics.
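The benchmarking functions above only require a reducer object with a fit_transform method and an n_neighbors attribute. A minimal stand-in that satisfies this interface might look as follows; the class PCAReducer is hypothetical and uses plain PCA purely as a placeholder algorithm.

```python
import numpy as np

class PCAReducer:
    """Minimal stand-in reducer exposing the interface the benchmarks expect."""

    def __init__(self, dimension=2, n_neighbors=16):
        self.dimension = dimension
        self.n_neighbors = n_neighbors  # read when computing neighborhood scores

    def fit_transform(self, data):
        # Classical PCA via SVD of the centered data.
        centered = data - data.mean(axis=0)
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        return centered @ vt[: self.dimension].T
```

An instance of such a class could be passed as the reducer argument of viz_benchmark, do_metrics, or run_benchmark in place of a DiRe object.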
Kernelized kNN index
A JAX-based implementation for efficient k-nearest neighbors.
- class HPIndex[source]
Bases:
object
A kernelized kNN index that uses batching / tiling to efficiently handle large datasets with limited memory usage.
- static knn_tiled(x, y, k=5, x_tile_size=8192, y_batch_size=1024, dtype=jnp.float64)[source]
Advanced implementation that tiles both database and query points. This wrapper handles the dynamic aspects before calling the JIT-compiled function.
- Parameters:
x – (n, d) array of database points
y – (m, d) array of query points
k – number of nearest neighbors
x_tile_size – size of database tiles
y_batch_size – size of query batches
dtype – desired floating-point dtype (e.g., jnp.float32 or jnp.float64)
- Returns:
(m, k) array of indices of nearest neighbors
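The tiling idea can be sketched in plain NumPy. This is not the library's JIT-compiled implementation: the function name knn_tiled_sketch is hypothetical, squared Euclidean distance is assumed, and k must not exceed the number of database points. The point of the sketch is that only one (batch, tile) block of distances exists in memory at a time, while a running top-k is merged across tiles.

```python
import numpy as np

def knn_tiled_sketch(x, y, k=5, x_tile_size=8192, y_batch_size=1024):
    """Brute-force kNN that never builds the full (m, n) distance matrix."""
    n, m = x.shape[0], y.shape[0]
    out = np.empty((m, k), dtype=np.int64)
    for ys in range(0, m, y_batch_size):
        yb = y[ys:ys + y_batch_size]
        best_d = np.full((yb.shape[0], k), np.inf)
        best_i = np.full((yb.shape[0], k), -1, dtype=np.int64)
        for xs in range(0, n, x_tile_size):
            xt = x[xs:xs + x_tile_size]
            # Squared Euclidean distances between this query batch and this tile.
            d = ((yb[:, None, :] - xt[None, :, :]) ** 2).sum(axis=-1)
            idx = np.broadcast_to(np.arange(xs, xs + xt.shape[0]), d.shape)
            # Merge the running best with the tile's candidates, keep the k smallest.
            cand_d = np.concatenate([best_d, d], axis=1)
            cand_i = np.concatenate([best_i, idx], axis=1)
            order = np.argsort(cand_d, axis=1, kind="stable")[:, :k]
            best_d = np.take_along_axis(cand_d, order, axis=1)
            best_i = np.take_along_axis(cand_i, order, axis=1)
        out[ys:ys + yb.shape[0]] = best_i
    return out

rng = np.random.default_rng(0)
x = rng.normal(size=(50, 3))   # database points
y = rng.normal(size=(7, 3))    # query points
neighbors = knn_tiled_sketch(x, y, k=4, x_tile_size=16, y_batch_size=3)
print(neighbors.shape)
```

Smaller tile and batch sizes lower peak memory at the cost of more merge steps.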
Performance Metrics
Auxiliary functions for high-performance benchmarking metrics
- welford_update(carry, new_value, finite_threshold=1000000000000.0)[source]
Update running mean and variance using Welford’s algorithm, ignoring values beyond the given finite_threshold.
- welford_finalize(agg)[source]
Finalize the computation of mean and variance from the aggregate statistics.
- Parameters:
agg ((tuple) A tuple containing the aggregated statistics:) –
count: (int) The total count of valid (non-NaN) entries.
mean: (float) The computed mean of the dataset.
M2: (float) The computed sum of squares of differences from the mean.
- Returns:
A tuple containing the final mean and standard deviation of the dataset.
- Return type:
tuple
- welford(data)[source]
Compute the mean and standard deviation of a dataset using Welford’s algorithm.
- Parameters:
data ((jax.numpy.ndarray) An array of data points, potentially containing NaNs which are ignored.)
- Returns:
A tuple containing the mean and standard deviation of the valid entries in the dataset.
- Return type:
tuple
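For orientation, Welford's update and finalize steps can be written in a few lines of pure Python. The names with the _sketch suffix are hypothetical, the carry is assumed to be the (count, mean, M2) tuple described for welford_finalize, and the use of the sample standard deviation (division by count - 1) is an assumption rather than a documented choice.

```python
import math

def welford_update_sketch(carry, new_value, finite_threshold=1e12):
    """Fold one value into (count, mean, M2), skipping NaNs and huge values."""
    count, mean, m2 = carry
    if math.isnan(new_value) or abs(new_value) > finite_threshold:
        return carry
    count += 1
    delta = new_value - mean
    mean += delta / count
    m2 += delta * (new_value - mean)  # uses the already updated mean
    return count, mean, m2

def welford_finalize_sketch(agg):
    """Turn (count, mean, M2) into (mean, std)."""
    count, mean, m2 = agg
    if count < 2:
        return mean, 0.0
    return mean, math.sqrt(m2 / (count - 1))  # sample std: an assumption

agg = (0, 0.0, 0.0)
for v in [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0, float("nan"), 1e15]:
    agg = welford_update_sketch(agg, v)
mean, std = welford_finalize_sketch(agg)
print(agg[0], mean, std)
```

The NaN and the value 1e15 are skipped, so only eight entries contribute to the statistics.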
- make_knn_graph(data, n_neighbors, batch_size=None)[source]
Compute the distances to nearest neighbors and their indices in the kNN graph of data.
- Parameters:
data (numpy.ndarray) – High-dimensional data points.
n_neighbors (int) – Number of nearest neighbors to find for each point.
batch_size (int or None, optional) – Number of samples to process at once. If None, a suitable value will be automatically determined based on dataset size.
- Returns:
Tuple containing:
distances: Array of shape (n_samples, n_neighbors+1) with distances to nearest neighbors.
indices: Array of shape (n_samples, n_neighbors+1) with indices of nearest neighbors.
The first column contains each point’s self-reference (distance 0.0 and own index). The remaining columns contain the n_neighbors nearest neighbors in ascending order of distance.
- Return type:
tuple
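The return layout, with the self-reference in column 0, can be reproduced by a small brute-force sketch. The name knn_graph_sketch is hypothetical, Euclidean distance is assumed, and the full pairwise matrix is built for clarity, so this is only suitable for small inputs without duplicate points.

```python
import numpy as np

def knn_graph_sketch(data, n_neighbors):
    """Return (distances, indices), each of shape (n_samples, n_neighbors + 1)."""
    diff = data[:, None, :] - data[None, :, :]
    dist = np.sqrt((diff ** 2).sum(axis=-1))
    # Each point is at distance 0.0 from itself, so it sorts into column 0.
    indices = np.argsort(dist, axis=1, kind="stable")[:, :n_neighbors + 1]
    distances = np.take_along_axis(dist, indices, axis=1)
    return distances, indices

rng = np.random.default_rng(0)
data = rng.normal(size=(20, 4))
distances, indices = knn_graph_sketch(data, 3)
print(distances.shape, indices[:3, 0])
```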
- compute_stress(data, layout, n_neighbors, eps=1e-06)[source]
Compute the stress of an embedding based on the distances in the original high-dimensional space and the embedded space, using a ratio of distances.
- Parameters:
data ((numpy.ndarray) High-dimensional data points.)
layout ((numpy.ndarray) Embedded data points.)
n_neighbors ((int) Number of nearest neighbors to consider for each point.)
eps ((float) Parameter to prevent zero division if mean distortion is near zero, default 1e-6.)
- Returns:
The normalized stress value indicating the quality of the embedding.
- Return type:
float
- compute_neighbor_score(data, layout, n_neighbors)[source]
Computes the neighborhood preservation score between high-dimensional data and its corresponding low-dimensional layout.
The function evaluates how well the neighborhood relationships are preserved when data is projected from a high-dimensional space to a lower-dimensional space using the K-nearest neighbors approach. This involves comparing the nearest neighbors in the original space with those in the reduced space.
- Parameters:
data ((numpy.ndarray) A NumPy array of shape (n_samples, data_dim) containing the original high-dimensional data.)
layout ((numpy.ndarray) A NumPy array of shape (n_samples, embed_dim) containing the lower-dimensional embedding of the data.)
n_neighbors ((int) The number of nearest neighbors to consider for each data point.)
- Returns:
A list containing two floats:
neighbor_mean: (float) The mean of the neighborhood preservation scores.
neighbor_std: (float) The standard deviation of the neighborhood preservation scores.
- Return type:
list
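One common way to score neighborhood preservation is the fraction of each point's k nearest neighbors in the original space that are also among its k nearest neighbors in the layout. The sketch below assumes that definition, which the documentation does not spell out; the helper names are hypothetical and the brute-force neighbor search is only meant for small inputs.

```python
import numpy as np

def _knn_indices(points, k):
    """Indices of the k nearest neighbors of each point, excluding the point itself."""
    d = ((points[:, None, :] - points[None, :, :]) ** 2).sum(axis=-1)
    np.fill_diagonal(d, np.inf)
    return np.argsort(d, axis=1, kind="stable")[:, :k]

def neighbor_score_sketch(data, layout, n_neighbors):
    """Return [mean, std] of the per-point neighbor overlap fraction (assumed definition)."""
    hd = _knn_indices(data, n_neighbors)
    ld = _knn_indices(layout, n_neighbors)
    overlap = np.array([len(set(a) & set(b)) for a, b in zip(hd, ld)])
    scores = overlap / n_neighbors
    return [float(scores.mean()), float(scores.std())]

rng = np.random.default_rng(0)
data = rng.normal(size=(30, 5))
perfect = neighbor_score_sketch(data, data.copy(), 4)
random_layout = neighbor_score_sketch(data, rng.normal(size=(30, 2)), 4)
print(perfect, random_layout)
```

An embedding identical to the data scores a mean of 1.0 under this definition, while an unrelated random layout scores much lower.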
- compute_local_metrics(data, layout, n_neighbors, memory_efficient=None)[source]
Compute local metrics of the (data, layout) pair.
- Parameters:
data ((numpy.ndarray) High-dimensional data points.)
layout ((numpy.ndarray) Low-dimensional data points corresponding to the high-dimensional data.)
n_neighbors ((int) Number of closest neighbors for the kNN graph.)
memory_efficient ((bool or None) If True, use memory-efficient algorithms for large datasets.) – If None, automatically determine based on dataset size.
- Returns:
A dictionary containing computed scores of each type (stress, neighborhood preservation).
- Return type:
dict
- threshold_subsample(*arrays, threshold, rng_key)[source]
Subsample multiple arrays based on a specified threshold. The function generates random numbers and selects the samples where the random number is less than the threshold.
- Parameters:
*arrays ((tuple of numpy.ndarray)) – The input data arrays to be subsampled. Each array should have the same number of samples (rows).
threshold ((float)) – Probability threshold for subsampling; only samples with generated random numbers below this value are kept.
rng_key (Random key or random generator used for generating random numbers, ensuring reproducibility.)
- Returns:
A tuple containing the subsampled arrays in the same order as the input arrays.
- Return type:
tuple
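The masking step described above can be sketched with NumPy. A numpy.random.Generator replaces the JAX random key here purely to keep the example dependency-light, and the name threshold_subsample_sketch is hypothetical. The same boolean mask is applied to every array, so rows stay aligned across arrays.

```python
import numpy as np

def threshold_subsample_sketch(*arrays, threshold, rng):
    """Keep row i of every array when a uniform draw in [0, 1) falls below threshold."""
    n_samples = arrays[0].shape[0]
    mask = rng.random(n_samples) < threshold
    return tuple(a[mask] for a in arrays)

points = np.arange(300).reshape(100, 3)
labels = np.arange(100)
sub_points, sub_labels = threshold_subsample_sketch(
    points, labels, threshold=0.5, rng=np.random.default_rng(0)
)
print(sub_points.shape[0], sub_labels.shape[0])
```

Because each row is kept independently, the number of retained samples varies around threshold times the number of rows rather than being fixed.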
- diagrams(data, layout, max_dim, subsample_threshold, rng_key)[source]
Generate persistence diagrams for high-dimensional and low-dimensional data up to a specified dimension, after subsampling both datasets based on a threshold. The subsampling is performed to reduce the dataset size and potentially highlight more relevant features when computing topological summaries.
- Parameters:
data ((numpy.ndarray) High-dimensional data points.)
layout ((numpy.ndarray) Low-dimensional data points corresponding to the high-dimensional data.)
max_dim ((int) Maximum dimension of homology groups to compute.)
subsample_threshold ((float) Threshold used for subsampling the data points.)
rng_key (Random key used for generating random numbers for subsampling, ensuring reproducibility.)
- Returns:
A dictionary containing two keys, ‘data’ and ‘layout’, each associated with arrays of persistence diagrams for the respective high-dimensional and low-dimensional datasets.
- Return type:
dict
- betti_curve(diagram, n_steps=100)[source]
Compute the Betti curve from a persistence diagram, which is a function of the number of features that persist at different filtration levels. This curve provides a summary of topological features across scales.
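- Parameters:
diagram (A persistence diagram, such as one of those returned by diagrams.)
n_steps ((int) Number of evenly spaced filtration values at which the curve is evaluated, default 100.)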
- Returns:
A tuple of two numpy arrays:
The first array represents the evenly spaced filtration values.
The second array represents the Betti numbers at each filtration value.
- Return type:
tuple
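A Betti curve counts how many bars of the diagram are alive at each filtration value. The sketch below assumes the diagram is an array of (birth, death) rows for a single homology dimension, treats each bar as the half-open interval [birth, death), and drops bars with infinite death; these conventions and the name betti_curve_sketch are assumptions, not the library's documented behavior.

```python
import numpy as np

def betti_curve_sketch(diagram, n_steps=100):
    """Return (filtration values, number of bars alive at each value)."""
    diagram = np.asarray(diagram, dtype=float).reshape(-1, 2)
    finite = diagram[np.isfinite(diagram).all(axis=1)]  # drop infinite bars
    if finite.shape[0] == 0:
        return np.zeros(n_steps), np.zeros(n_steps)
    births, deaths = finite[:, 0], finite[:, 1]
    axis_x = np.linspace(births.min(), deaths.max(), n_steps)
    # alive[s, b] is True when bar b contains filtration value axis_x[s].
    alive = (births[None, :] <= axis_x[:, None]) & (axis_x[:, None] < deaths[None, :])
    return axis_x, alive.sum(axis=1)

axis_x, axis_y = betti_curve_sketch([[0.0, 2.0], [1.0, 3.0]], n_steps=4)
print(axis_x, axis_y)
```

For the two bars [0, 2) and [1, 3) sampled at 0, 1, 2, 3, the curve rises to 2 where the bars overlap and falls to 0 at the last death value.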
- compute_dtw(axis_x_hd, axis_y_hd, axis_x_ld, axis_y_ld, norm_factor=1.0)[source]
Compute the normalized Dynamic Time Warping (DTW) distance, using the Euclidean metric, between two Betti curves, each represented as a time series with time axis x and values y.
- Parameters:
axis_x_hd ((numpy.ndarray) Time axis of the high-dimensional Betti curve.)
axis_y_hd ((numpy.ndarray) Values of the high-dimensional Betti curve.)
axis_x_ld ((numpy.ndarray) Time axis of the low-dimensional Betti curve.)
axis_y_ld ((numpy.ndarray) Values of the low-dimensional Betti curve.)
norm_factor ((float) Normalization factor, default 1.0.)
- Returns:
Normalized DTW distance between two Betti curves.
- Return type:
float
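The classic dynamic-programming form of DTW is sketched below for two curves given as (x, y) point sequences. It is a textbook quadratic-time version, not the library's implementation; the name dtw_sketch is hypothetical, and dividing the accumulated cost by norm_factor is an assumption about how the normalization is applied.

```python
import numpy as np

def dtw_sketch(axis_x_hd, axis_y_hd, axis_x_ld, axis_y_ld, norm_factor=1.0):
    """DTW distance between two curves, with Euclidean cost between (x, y) points."""
    a = np.column_stack([axis_x_hd, axis_y_hd])
    b = np.column_stack([axis_x_ld, axis_y_ld])
    cost = np.linalg.norm(a[:, None, :] - b[None, :, :], axis=-1)
    n, m = cost.shape
    acc = np.full((n + 1, m + 1), np.inf)
    acc[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # Extend the cheapest of the three admissible predecessor alignments.
            acc[i, j] = cost[i - 1, j - 1] + min(
                acc[i - 1, j], acc[i, j - 1], acc[i - 1, j - 1]
            )
    return float(acc[n, m] / norm_factor)  # division by norm_factor is assumed

x = [0.0, 1.0, 2.0]
d_same = dtw_sketch(x, [0.0, 1.0, 0.0], x, [0.0, 1.0, 0.0])
d_diff = dtw_sketch(x, [0.0, 1.0, 0.0], x, [0.0, 2.0, 0.0])
print(d_same, d_diff)
```

Identical curves give a distance of zero; the second pair differs only in the height of the middle point, and the diagonal alignment is the cheapest path.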
- compute_twed(axis_x_hd, axis_y_hd, axis_x_ld, axis_y_ld, norm_factor=1.0)[source]
Compute the normalized Time Warp Edit Distance (TWED), using the Euclidean metric, between two Betti curves, each represented as a time series with time axis x and values y.
- Parameters:
axis_x_hd ((numpy.ndarray) Time axis of the high-dimensional Betti curve.)
axis_y_hd ((numpy.ndarray) Values of the high-dimensional Betti curve.)
axis_x_ld ((numpy.ndarray) Time axis of the low-dimensional Betti curve.)
axis_y_ld ((numpy.ndarray) Values of the low-dimensional Betti curve.)
norm_factor ((float) Normalization factor, default 1.0.)
- Returns:
Normalized TWED distance between two Betti curves.
- Return type:
float
- compute_emd(axis_x_hd, axis_y_hd, axis_x_ld, axis_y_ld, adjust_mass=False, norm_factor=1.0)[source]
Compute the normalized Earth Mover's Distance (EMD), using the Euclidean metric, between two Betti curves, each represented as a time series with time axis x and values y.
- Parameters:
axis_x_hd ((numpy.ndarray) Time axis of the high-dimensional Betti curve.)
axis_y_hd ((numpy.ndarray) Values of the high-dimensional Betti curve.)
axis_x_ld ((numpy.ndarray) Time axis of the low-dimensional Betti curve.)
axis_y_ld ((numpy.ndarray) Values of the low-dimensional Betti curve.)
adjust_mass ((bool) Use to adjust mass (by default, EMD is computed for unit mass curves);) – default False.
norm_factor ((float) Normalization factor, default 1.0.)
- Returns:
Normalized EMD distance between two Betti curves.
- Return type:
float
- compute_wasserstein(diag_hd, diag_ld, norm_factor=1.0)[source]
Compute normalized Wasserstein distance between two persistence diagrams (usually one of high-dimensional data and one of low-dimensional data).
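- Parameters:
diag_hd (Persistence diagram of the high-dimensional data.)
diag_ld (Persistence diagram of the low-dimensional data.)
norm_factor ((float) Normalization factor, default 1.0.)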
- Returns:
Normalized Wasserstein distance between persistence diagrams.
- Return type:
float
- compute_bottleneck(diag_hd, diag_ld, norm_factor=1.0)[source]
Compute normalized bottleneck distance between two persistence diagrams (usually one of high-dimensional data and one of low-dimensional data).
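- Parameters:
diag_hd (Persistence diagram of the high-dimensional data.)
diag_ld (Persistence diagram of the low-dimensional data.)
norm_factor ((float) Normalization factor, default 1.0.)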
- Returns:
Normalized bottleneck distance between persistence diagrams.
- Return type:
float
- compute_global_metrics(data, layout, dimension, subsample_threshold, rng_key, n_steps=100, metrics_only=True)[source]
Compute and compare persistence metrics between high-dimensional and low-dimensional data representations. The function calculates the Dynamic Time Warping (DTW) distance, the Time Warp Edit Distance (TWED), and the Earth Mover's Distance (EMD) based on Betti curves derived from persistence diagrams. It also calculates the Wasserstein distance and the bottleneck distance based on the persistence diagrams themselves. This evaluation helps quantify the topological fidelity of dimensionality reduction.
- Parameters:
data ((numpy.ndarray) High-dimensional data points.)
layout ((numpy.ndarray) Low-dimensional data points corresponding to the high-dimensional data.)
dimension ((int) The maximum dimension for which to compute persistence diagrams.)
subsample_threshold ((float) Threshold used for subsampling the data.)
rng_key (Random key used for generating random numbers for subsampling, ensuring reproducibility.)
n_steps ((int, optional) The number of steps or points in the filtration range for computing Betti curves.)
metrics_only ((bool) If True, return metrics only; otherwise diagrams and Betti curves are also returned;) – default True.
- Returns:
If metrics_only is True – dict(dict): A dictionary containing one item, ‘metrics’, that is a dictionary of lists of computed distances for each of the metrics (DTW, TWED, EMD, Wasserstein, and bottleneck). Each list is populated according to the dimensions in which the distances were computed.
If metrics_only is False – dict(dict, dict, dict): A dictionary containing three items:
‘metrics’: A dictionary of metrics, as described above;
‘diags’: A dictionary of diagrams for the initial data and for the layout;
‘bettis’: A dictionary of Betti curves for the initial data and for the layout.
Each of these is a dictionary of lists. Each list is populated according to the dimensions in which the distances, diagrams, or curves were computed.
- compute_svm_accuracy(X, y, test_size=0.3, reg_param=1.0, max_iter=100, random_state=42)[source]
Compute linear SVM classifier accuracy for given labelled data X with labels y.
- Parameters:
X ((numpy.ndarray) Data.)
y ((numpy.ndarray) Data labels.)
test_size ((float) Test size (between 0.0 and 1.0) for the train / test split, default 0.3.)
reg_param ((float) Regularization parameter for SVM, default 1.0.)
max_iter ((int) Maximal number of iterations for SVM training, default 100.)
random_state ((int) Random state for reproducibility, default 42.)
- Returns:
Accuracy of the linear SVM model on the test set.
- Return type:
float
- compute_svm_score(data, layout, labels, subsample_threshold, rng_key, **kwargs)[source]
Compute SVM score (context preservation measure) by comparing linear SVM classifier accuracies on the high-dimensional data and on the low-dimensional embedding.
- Parameters:
data ((numpy.ndarray) High-dimensional data.)
layout ((numpy.ndarray) Low-dimensional embedding.)
labels ((numpy.ndarray) Data labels.)
subsample_threshold ((float) Threshold used for subsampling the data.)
rng_key (Random key used for generating random numbers for subsampling, ensuring reproducibility.)
kwargs (Other keyword arguments used by the various scores above.)
- Returns:
SVM context preservation score.
- Return type:
float
- compute_knn_accuracy(X, y, n_neighbors=16, test_size=0.3, random_state=42)[source]
Compute kNN classifier accuracy for given labelled data X with labels y.
- Parameters:
X ((numpy.ndarray) Data.)
y ((numpy.ndarray) Data labels.)
test_size ((float) Test size (between 0.0 and 1.0) for the train / test split, default 0.3.)
n_neighbors ((int) Number of neighbors for kNN classification, default 16.)
random_state ((int) Random state for reproducibility, default 42.)
- Returns:
Accuracy of the kNN model on the test set.
- Return type:
float
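The train / test evaluation can be sketched with a brute-force majority-vote classifier in NumPy. This is an illustration of the idea only: the name knn_accuracy_sketch is hypothetical, labels are assumed to be non-negative integers, and the shuffling done here with a NumPy generator will not reproduce the splits of any particular library.

```python
import numpy as np

def knn_accuracy_sketch(X, y, n_neighbors=16, test_size=0.3, random_state=42):
    """Hold out a test split, classify it by majority vote among training neighbors."""
    rng = np.random.default_rng(random_state)
    perm = rng.permutation(X.shape[0])
    n_test = int(round(test_size * X.shape[0]))
    test, train = perm[:n_test], perm[n_test:]
    # Squared Euclidean distances from every test point to every training point.
    d = ((X[test][:, None, :] - X[train][None, :, :]) ** 2).sum(axis=-1)
    nearest = np.argsort(d, axis=1)[:, :n_neighbors]
    votes = y[train][nearest]  # labels of each test point's neighbors
    pred = np.array([np.bincount(row).argmax() for row in votes])
    return float((pred == y[test]).mean())

rng = np.random.default_rng(0)
X = np.vstack([rng.normal(0.0, 1.0, (40, 2)), rng.normal(100.0, 1.0, (40, 2))])
y = np.array([0] * 40 + [1] * 40)
accuracy = knn_accuracy_sketch(X, y, n_neighbors=5)
print(accuracy)
```

On two well-separated clusters every test point's neighbors share its label, so the accuracy is perfect; overlapping classes would lower it.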
- compute_knn_score(data, layout, labels, n_neighbors=16, **kwargs)[source]
Compute kNN score (context preservation measure) by comparing kNN classifier accuracies on the high-dimensional data and on the low-dimensional embedding.
- Parameters:
data ((numpy.ndarray) High-dimensional data.)
layout ((numpy.ndarray) Low-dimensional embedding.)
labels ((numpy.ndarray) Data labels.)
n_neighbors ((int) Number of nearest neighbors for kNN classifier, default 16.)
kwargs (Other keyword arguments used by the various scores above.)
- Returns:
kNN context preservation score.
- Return type:
float
- compute_quality_measures(data, layout, n_neighbors=None)[source]
Compute quality measures for assessing the quality of dimensionality reduction.
This function calculates various metrics that evaluate how well the low-dimensional representation preserves important properties of the high-dimensional data.
- Parameters:
data (numpy.ndarray) – High-dimensional data points.
layout (numpy.ndarray) – Low-dimensional embedding of the data.
- Returns:
Dictionary of quality measures including:
trustworthiness: Measures whether points that are close in the embedding are also close in the original space.
continuity: Measures whether points that are close in the original space are also close in the embedding.
shepard_correlation: Correlation between pairwise distances in the original and embedded spaces.
- Return type:
dict
- compute_context_measures(data, layout, labels, subsample_threshold, n_neighbors, rng_key, **kwargs)[source]
Compute measures of how well the embedding preserves the context of the data.
- Parameters:
data (numpy.ndarray) – High-dimensional data points.
layout (numpy.ndarray) – Low-dimensional embedding of the data.
labels (numpy.ndarray) – Data labels needed for context preservation analysis.
subsample_threshold (float) – Threshold used for subsampling the data.
n_neighbors (int) – Number of neighbors for the kNN graph.
rng_key (jax.random.PRNGKey) – Random key for reproducible subsampling.
**kwargs – Additional keyword arguments for the scoring functions.
- Returns:
Dictionary of context preservation measures, including SVM and kNN classification performance comparisons.
- Return type:
dict