API Reference

DiRe-JAX is a high-performance dimensionality reduction library built on JAX, designed for efficient processing of large-scale, high-dimensional datasets.

DiRe JAX Overview

A JAX-based dimensionality reducer.

class DiRe(dimension=2, n_neighbors=16, init_embedding_type='random', sim_kernel=None, pca_kernel=None, max_iter_layout=128, min_dist=0.01, spread=1.0, cutoff=42.0, n_sample_dirs=8, sample_size=16, neg_ratio=8, my_logger=None, verbose=True, memm=None, mpa=True)[source]

Bases: object

Dimension Reduction (DiRe) is a class designed to reduce the dimensionality of high-dimensional data using various embedding techniques and optimization algorithms. It supports random, PCA, and spectral embedding initialization, and uses a k-nearest-neighbor graph for manifold learning.
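A minimal usage sketch is shown below. The import path is an assumption (adjust it to how the package is installed in your environment), and the data is synthetic.

    import numpy as np
    from dire_jax import DiRe  # assumed import path

    # Toy high-dimensional data: 1000 points in 50 dimensions
    data = np.random.default_rng(0).normal(size=(1000, 50))

    reducer = DiRe(dimension=2, n_neighbors=16, init_embedding_type='random')
    layout = reducer.fit_transform(data)   # shape (1000, 2)
    fig = reducer.visualize(point_size=3)  # Plotly figure, or None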

Parameters:
  • dimension (int) – Embedding dimension, default 2.

  • n_neighbors (int) – Number of nearest neighbors to consider for each point, default 16.

  • init_embedding_type (str) –

    Method to initialize the embedding; choices are:
    • ’random’ for random projection based on the Johnson-Lindenstrauss lemma;

    • ’spectral’ for spectral embedding (with sim_kernel as similarity kernel);

    • ’pca’ for PCA embedding (classical PCA if pca_kernel is None, Kernel PCA otherwise).

    By default, ‘random’.

  • sim_kernel (callable or None) – A similarity kernel that transforms a distance to a similarity score, of the form lambda distance: float -> similarity: float; default None.

  • pca_kernel (callable or None) – Kernel to use for Kernel PCA when init_embedding_type is ‘pca’; if None, classical PCA is used; default None.

  • max_iter_layout (int) – Maximum number of iterations to run the layout optimization, default 128.

  • min_dist (float) – Minimum distance scale for the distribution kernel, default 0.01.

  • spread (float) – Spread of the distribution kernel, default 1.0.

  • cutoff (float) – Cutoff for clipping forces during layout optimization, default 42.0.

  • n_sample_dirs (int) – Number of directions to sample in random sampling, default 8.

  • sample_size (int) – Number of samples per direction in random sampling, default 16.

  • neg_ratio (int) – Ratio of negative to positive samples in random sampling, default 8.

  • my_logger (logging.Logger or None) – Custom logger for logging events; if None, a default logger is created; default None.

  • verbose (bool) – Flag to enable verbose output, default True.

  • memm (dict or None) – Memory manager: a dictionary of batch / memory tile sizes per hardware architecture, with keys ‘tpu’, ‘gpu’, and ‘other’ and positive-integer values; default None.

  • mpa (bool) – Mixed precision arithmetic flag (True = use mixed precision, False = always use float64); default True.

dimension

Target dimensionality of the output space.

Type:

int

n_neighbors

Number of neighbors to consider in the k-nearest neighbors graph.

Type:

int

init_embedding_type

Chosen method for initial embedding.

Type:

str

sim_kernel

Similarity kernel function to be used if ‘init_embedding_type’ is ‘spectral’, by default None.

Type:

callable

pca_kernel

Kernel function to be used if ‘init_embedding_type’ is ‘pca’, by default None.

Type:

callable

max_iter_layout

Maximum iterations for optimizing the layout.

Type:

int

min_dist

Minimum distance for repulsion used in the distribution kernel.

Type:

float

spread

Spread between the data points used in the distribution kernel.

Type:

float

cutoff

Maximum cutoff for forces during optimization.

Type:

float

n_sample_dirs

Number of random directions sampled.

Type:

int

sample_size

Number of samples per random direction, unless chosen automatically with ‘auto’.

Type:

int or ‘auto’

neg_ratio

Ratio of negative to positive samples in the sampling process.

Type:

int

logger

Logger used for logging informational and warning messages.

Type:

logging.Logger or None

verbose

Logger output flag (True = output logger messages, False = flush to null)

Type:

bool

memm

Memory manager: a dictionary with the batch / memory tile size for different hardware architectures. Accepts ‘tpu’, ‘gpu’ and ‘other’ as keys. Values must be positive integers.

Type:

dictionary or None

mpa

Mixed Precision Arithmetic flag (True = use MPA, False = always use float64)

Type:

bool

fit_transform(data)[source]

A convenience method that fits the model and then transforms the data. The separate fit and transform methods must be called in sequence on the same dataset, because the dimension reduction is computed for the dataset as a whole.

visualize(labels=None, point_size=2)[source]

Visualizes the transformed data, optionally using labels to color the points.

__init__(dimension=2, n_neighbors=16, init_embedding_type='random', sim_kernel=None, pca_kernel=None, max_iter_layout=128, min_dist=0.01, spread=1.0, cutoff=42.0, n_sample_dirs=8, sample_size=16, neg_ratio=8, my_logger=None, verbose=True, memm=None, mpa=True)[source]

Class constructor

do_layout(large_dataset_mode=None, force_cpu=False)[source]

Optimize the layout using force-directed placement with JAX kernels.

This method takes the initial embedding and iteratively refines it using attractive and repulsive forces to create a meaningful low-dimensional representation of the high-dimensional data. The algorithm applies:

  1. Attraction forces between points that are neighbors in the high-dimensional space

  2. Repulsion forces between randomly sampled points in the low-dimensional space

  3. Gradual cooling (decreasing force impact) as iterations progress

The final layout is normalized to have zero mean and unit standard deviation; the iteration structure is sketched just after the parameter list below.

Parameters:
  • large_dataset_mode (bool or None, optional) – If True, use memory-efficient techniques for large datasets. If None, automatically determine based on dataset size.

  • force_cpu (bool, optional) – If True, force computations on CPU instead of GPU, which can be helpful for large datasets that exceed GPU memory.
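The sketch below illustrates the iteration structure described above. It is not the library's actual kernel code (the real force computation is compute_forces_kernel, documented further below), and the linear cooling schedule shown is an assumption.

    import jax.numpy as jnp

    def run_layout_sketch(positions, max_iter_layout, cutoff, compute_forces):
        for it in range(max_iter_layout):
            alpha = 1.0 - it / max_iter_layout         # gradual cooling (assumed schedule)
            forces = compute_forces(positions, alpha)  # attraction + repulsion
            positions = positions + jnp.clip(forces, -cutoff, cutoff)
        # normalize to zero mean and unit standard deviation
        return (positions - positions.mean(axis=0)) / positions.std(axis=0)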

do_pca_embedding()[source]

Initialize embedding using Principal Component Analysis (PCA).

This method creates an initial embedding of the data using PCA, which finds a linear projection of the high-dimensional data into a lower-dimensional space that maximizes the variance. If a kernel is specified, Kernel PCA is used instead, which can capture nonlinear relationships.

Sets the init_embedding attribute with the PCA projection of the data.

do_rand_sampling(key, arr, n_samples, n_dirs, neg_ratio)[source]

Sample points for force calculation using random projections.

This method implements an efficient sampling strategy to identify points for applying attractive and repulsive forces during layout optimization. It uses random projections to quickly identify nearby points in different directions, and also adds random negative samples for repulsion.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key

  • arr (jax.numpy.ndarray) – Array of current point positions

  • n_samples (int) – Number of samples to take in each direction

  • n_dirs (int) – Number of random directions to sample

  • neg_ratio (int) – Ratio of negative samples to positive samples

Returns:

Array of sampled indices for force calculations

Return type:

jax.numpy.ndarray

Notes

The sampling strategy works as follows:

  1. Generate n_dirs random unit vectors

  2. Project the points onto each vector

  3. For each point, take the n_samples closest points in each direction

  4. Add random negative samples for repulsion

  5. Combine all sampled indices

This approach is more efficient than a full nearest neighbor search while still capturing the important local relationships.

do_random_embedding()[source]

Initialize embedding using Random Projection.

This method creates an initial embedding of the data using random projection, which is a simple and computationally efficient technique for dimensionality reduction. It projects the data onto a randomly generated basis, providing a good starting point for further optimization.

Random projection is supported by the Johnson-Lindenstrauss lemma, which guarantees that the distances between points are approximately preserved under certain conditions.

Sets the init_embedding attribute with the random projection of the data.

do_spectral_embedding()[source]

Initialize embedding using Spectral Embedding.

This method creates an initial embedding of the data using spectral embedding, which is based on the eigenvectors of the graph Laplacian. It relies on the kNN graph structure to find a lower-dimensional representation that preserves local relationships.

If a similarity kernel is specified, it is applied to transform the distances in the adjacency matrix before computing the Laplacian.

Sets the init_embedding attribute with the spectral embedding of the data.

find_ab_params(min_dist=0.01, spread=1.0)[source]

Rational function approximation to the probabilistic t-kernel
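The docstring is terse; below is a sketch of the standard (UMAP-style) way to obtain such parameters, fitting the rational kernel 1 / (1 + a * d^(2b)) to an exponential target controlled by min_dist and spread. The library's exact fitting procedure may differ.

    import numpy as np
    from scipy.optimize import curve_fit

    def find_ab_params_sketch(min_dist=0.01, spread=1.0):
        def kernel(d, a, b):
            return 1.0 / (1.0 + a * d ** (2.0 * b))
        d = np.linspace(0.0, 3.0 * spread, 300)
        # target: 1 for d < min_dist, exponential decay beyond it
        target = np.ones_like(d)
        mask = d >= min_dist
        target[mask] = np.exp(-(d[mask] - min_dist) / spread)
        (a, b), _ = curve_fit(kernel, d, target)
        return a, b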

fit(data)[source]

Fit the model to data: create the kNN graph and fit the probability kernel to force layout parameters.

Parameters:

data (numpy.ndarray) – High-dimensional data to fit the model. Shape (n_samples, n_features).

Returns:

self – the DiRe instance fitted to the data.

Return type:

DiRe

fit_transform(data)[source]

Fit the model to data and transform it into a low-dimensional layout.

This is a convenience method that combines the fitting and transformation steps. It first builds the kNN graph and then creates the optimized layout in a single operation.

Parameters:

data (numpy.ndarray) – High-dimensional data to fit and transform. Shape (n_samples, n_features)

Returns:

The lower-dimensional embedding of the data. Shape (n_samples, dimension)

Return type:

numpy.ndarray

Notes

This method is more memory-efficient than calling fit() and transform() separately, as it avoids storing intermediate results.

make_knn_adjacency(batch_size=None)[source]

Internal routine building the adjacency matrix for the kNN graph.

This method computes the k-nearest neighbors for each point in the dataset and constructs a sparse adjacency matrix representing the kNN graph. It attempts to use GPU acceleration if available, with a fallback to CPU.

For large datasets, it uses batching to limit memory usage.

Parameters:
  • batch_size (int or None, optional) – Number of samples to process at once. If None, a suitable value is determined automatically based on dataset size.

The method sets the following instance attributes:

  • distances – distances to the k nearest neighbors of each point

  • indices – indices of the k nearest neighbors of each point

  • nearest_neighbor_distances – distances to each point’s nearest neighbor

  • row_idx, col_idx – indices for constructing the sparse adjacency matrix

  • adjacency – the sparse adjacency matrix of the kNN graph

transform()[source]

Transform the fitted data into a lower-dimensional layout.

This method applies the selected embedding initialization technique to the data that has already been fitted (creating the kNN graph), and then optimizes the layout using force-directed placement.

The transformation process involves:

  1. Creating an initial embedding using the specified method (random projection, PCA, or spectral embedding)

  2. Optimizing the layout with attractive and repulsive forces

Returns:

The lower-dimensional data embedding with shape (n_samples, dimension). Points are arranged to preserve the local structure of the original data.

Return type:

numpy.ndarray

Raises:

ValueError – If an unsupported embedding initialization method is specified.

visualize(labels=None, point_size=2, title=None, colormap=None, width=800, height=600, opacity=0.7)[source]

Generate an interactive visualization of the data in the transformed space.

This method creates a scatter plot visualization of the embedded data, supporting both 2D and 3D visualizations depending on the specified dimension. Points can be colored by provided labels for clearer visualization of clusters or categories.

Parameters:
  • labels (numpy.ndarray or None, optional) – Labels for each data point to color the points in the visualization. If None, all points will have the same color. Default is None.

  • point_size (int or float, optional) – Size of points in the scatter plot. Default is 2.

  • title (str or None, optional) – Title for the visualization. If None, a default title will be used. Default is None.

  • colormap (str or None, optional) – Name of the colormap to use for labels (e.g., ‘viridis’, ‘plasma’). If None, the default Plotly colormap will be used. Default is None.

  • width (int, optional) – Width of the figure in pixels. Default is 800.

  • height (int, optional) – Height of the figure in pixels. Default is 600.

  • opacity (float, optional) – Opacity of the points (0.0 to 1.0). Default is 0.7.

Returns:

A Plotly figure object if the visualization is successful; None if no layout is available or dimension > 3.

Return type:

plotly.graph_objs._figure.Figure or None

Notes

For 3D visualizations, you can rotate, zoom, and pan the plot interactively. For both 2D and 3D, hover over points to see their coordinates and labels.
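For example, assuming reducer has been fitted and transformed as above, and labels is an array of per-point labels:

    fig = reducer.visualize(labels=labels, point_size=3,
                            colormap='viridis', title='DiRe embedding')
    if fig is not None:
        fig.show()  # interactive Plotly figure (2D or 3D)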


Core Components

DiRe Class

Provides the main class for dimensionality reduction.

The DiRe (Dimensionality Reduction) class implements a modern approach to dimensionality reduction, leveraging JAX for efficient computation. It uses force-directed layout techniques combined with k-nearest neighbor graph construction to generate meaningful low-dimensional embeddings of high-dimensional data. The class itself is documented in full in the DiRe JAX Overview above; the JAX kernels and helpers that support the layout optimization are documented below.


compute_forces_kernel(positions, chunk_indices, neighbor_indices, sample_indices, alpha, a, b)[source]

JAX-optimized kernel for computing attractive and repulsive forces.

Parameters:
  • positions (jax.numpy.ndarray) – Current point positions

  • chunk_indices (jax.numpy.ndarray) – Current batch indices

  • neighbor_indices (jax.numpy.ndarray) – Indices of neighbors for attractive forces

  • sample_indices (jax.numpy.ndarray) – Indices of points for repulsive forces

  • alpha (float) – Cooling factor that scales force magnitude

  • a (float) – Attraction parameter

  • b (float) – Repulsion parameter

Returns:

Net force vectors for each point

Return type:

jax.numpy.ndarray

distribution_kernel(dist, a, b)[source]

Probability kernel that maps distances to similarity scores.

This is a rational function approximation of a t-distribution.

Parameters:
  • dist (jax.numpy.ndarray) – Distance values

  • a (float) – Scale parameter that controls the steepness of the distribution

  • b (float) – Shape parameter that controls the tail behavior

Returns:

Similarity score(s) between 0 and 1

Return type:

float or jax.numpy.ndarray
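A plausible closed form for such a kernel (the customary rational approximation; the library's exact expression may differ) is:

    import jax.numpy as jnp

    def distribution_kernel_sketch(dist, a, b):
        # maps a distance d >= 0 to a similarity in (0, 1]
        return 1.0 / (1.0 + a * dist ** (2.0 * b))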

jax_coeff_att(dist, a, b)[source]

JAX-optimized attraction coefficient function.

jax_coeff_rep(dist, a, b)[source]

JAX-optimized repulsion coefficient function.

rand_directions(key, dim=2, num=100)[source]

Sample unit vectors in random directions.

Parameters:
  • key (jax.random.PRNGKey) – Random number generator key

  • dim (int) – Dimensionality of the vectors

  • num (int) – Number of random directions to sample

Returns:

Array of shape (num, dim) containing unit vectors

Return type:

jax.numpy.ndarray
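A standard way to implement this (a sketch, not necessarily the library's code) is to draw standard normal vectors and normalize them, which yields a uniform distribution on the unit sphere:

    import jax
    import jax.numpy as jnp

    def rand_directions_sketch(key, dim=2, num=100):
        vecs = jax.random.normal(key, (num, dim))
        return vecs / jnp.linalg.norm(vecs, axis=1, keepdims=True)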

get_slice(arr, k, i)[source]

Extract a slice of size k centered around index i.

Parameters:
  • arr (jax.numpy.ndarray) – Input array

  • k (int) – Size of the slice

  • i (int) – Center index position

Returns:

Slice of the input array

Return type:

jax.numpy.ndarray

Utility Functions

Utilities for plotting, benchmarking, and miscellany

display_layout(layout, labels, point_size=2)[source]
Parameters:
  • layout (numpy.ndarray) – Layout to display; must have dimension 2 or 3.

  • labels (numpy.ndarray) – Labels used to generate colors and the legend.

  • point_size (int) – Point size for plotting.

Returns:

A plot of the layout (via Plotly Express) if it has dimension 2 or 3; for higher-dimensional data no plot is provided and the function returns None.

Return type:

object or None

do_local_analysis(data, layout, n_neighbors)[source]
Parameters:
  • data (numpy.ndarray) – High-dimensional data.

  • layout (numpy.ndarray) – Low-dimensional embedding.

  • n_neighbors (int) – Number of neighbors in the kNN graph.

Returns:

None – prints the local metrics (embedding stress, neighborhood preservation score, etc.).

Return type:

None

visualize_persistence_diagram(diagram, dimension, title)[source]

Create a visualization of a persistence diagram.

Parameters:
  • diagram (list of tuples) – Persistence diagram represented as a list of (birth, death) tuples

  • dimension (int) – Homology dimension of the diagram

  • title (str) – Title for the visualization

Returns:

Interactive visualization of the persistence diagram

Return type:

plotly.graph_objects.Figure

do_persistence_analysis(data, layout, dimension, subsample_threshold, rng_key, n_steps=100)[source]

Perform a comprehensive persistence analysis by subsampling data, computing persistence diagrams, and calculating distances between Betti curves of high-dimensional and low-dimensional data. This analysis includes computing distances such as Dynamic Time Warp (DTW), Time Warp Edit Distance (TWED), and Earth Mover Distance.

Parameters:
  • data (numpy.ndarray) – High-dimensional data.

  • layout (numpy.ndarray) – Low-dimensional embedding.

  • dimension (int) – The dimension up to which persistence diagrams are computed.

  • subsample_threshold (float) – The threshold used for subsampling the data points.

  • rng_key (jax.random.PRNGKey) – Random key used for generating random numbers for subsampling, ensuring reproducibility.

  • n_steps (int, optional) – The number of steps or points in the filtration range for computing Betti curves, default 100.

Returns:

This function primarily visualizes results and prints metric values.

Return type:

None

do_context_analysis(data, layout, labels, subsample_threshold, n_neighbors, rng_key, **kwargs)[source]
Parameters:
  • data (numpy.ndarray) – High-dimensional data.

  • layout (numpy.ndarray) – Low-dimensional embedding.

  • labels (numpy.ndarray) – Data labels, required for context preservation analysis.

  • subsample_threshold (float) – Subsampling threshold.

  • n_neighbors (int) – Number of nearest neighbors for the kNN graph of the data.

  • rng_key (jax.random.PRNGKey) – Random key used for subsampling, ensuring reproducibility.

  • kwargs – Keyword arguments for the kNN and SVM score procedures, and similar.

Returns:

None – prints the context preservation measures.

Return type:

None

block_timing()[source]
Returns:

Elapsed runtime (in seconds) for the timed block of code.

Return type:

float

viz_benchmark(reducer, data, **kwargs)[source]

Run a benchmarking process for dimensionality reduction using provided reducer.

Parameters:
  • reducer (object) – Dimensionality reduction model with a fit_transform method; it should also have an n_neighbors attribute for computing neighborhood scores.

  • data (numpy.ndarray) – High-dimensional data to be reduced.

  • kwargs – Keyword arguments for the benchmark’s metrics (such as labels if using labeled data, the maximum dimension for persistence homology computation, the subsample_threshold for subsampling, etc.).

Returns:

None – performs the embedding, displays the layout, conducts persistence analysis, prints the embedding stress and neighborhood preservation score, and times the embedding process.

Return type:

None

do_metrics(reducer, data, **kwargs)[source]

Compute local and global metrics, and context preservation measures.

Parameters:
  • reducer (object) – The dimensionality reduction model with a fit_transform method; it should also have an n_neighbors attribute for computing neighborhood scores.

  • data (numpy.ndarray) – The high-dimensional data to be reduced.

  • kwargs – Keyword arguments passed to compute_local_metrics, compute_global_metrics, and compute_context_measures.

Returns:

A dictionary of local and global metrics, and context preservation measures.

Return type:

dict

run_benchmark(reducer, data, *, num_trials=100, only_stats=True, **kwargs)[source]

Benchmark a reducer on given data.

Parameters:
  • reducer (object) – The dimensionality reduction model with a fit_transform method; it should also have an n_neighbors attribute for computing neighborhood scores.

  • data (numpy.ndarray) – The high-dimensional data to be reduced as a benchmark.

  • num_trials (int) – The number of runs over which statistics are collected.

  • only_stats (bool) – If True, only the (mean, std) tuple for each metric is returned; if False, both the stats and the values from all runs are returned.

  • kwargs – Keyword arguments passed to do_metrics.

Returns:

If only_stats is True, a dictionary with the stats of all available metrics. If only_stats is False, two dictionaries: one with the stats and one with the raw values of all metrics.

Return type:

dict or (dict, dict)
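A usage sketch, assuming reducer and data are defined as in the DiRe example above; the subsample_threshold keyword is an assumption (it is forwarded to do_metrics), and the exact layout of the returned dictionary is best checked against the source.

    stats = run_benchmark(reducer, data, num_trials=10, only_stats=True,
                          subsample_threshold=0.1)  # kwargs forwarded to do_metrics
    print(stats)  # per-metric statistics collected over the trials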

Kernelized kNN Index

A JAX-based implementation for efficient k-nearest neighbors.

class HPIndex[source]

Bases: object

A kernelized kNN index that uses batching / tiling to efficiently handle large datasets with limited memory usage.

static knn_tiled(x, y, k=5, x_tile_size=8192, y_batch_size=1024, dtype=jnp.float64)[source]

Advanced implementation that tiles both database and query points. This wrapper handles the dynamic aspects before calling the JIT-compiled function.

Parameters:
  • x – (n, d) array of database points

  • y – (m, d) array of query points

  • k – number of nearest neighbors

  • x_tile_size – size of database tiles

  • y_batch_size – size of query batches

  • dtype – desired floating-point dtype (e.g., jnp.float32 or jnp.float64)

Returns:

(m, k) array of indices of nearest neighbors
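A usage sketch for a self-query (finding each point's neighbors within the same dataset); it assumes data is defined as in the DiRe example above, and array shapes follow the parameter descriptions.

    import jax.numpy as jnp

    x = jnp.asarray(data)  # database points, shape (n, d)
    neighbors = HPIndex.knn_tiled(x, x, k=16,
                                  x_tile_size=8192, y_batch_size=1024,
                                  dtype=jnp.float32)
    # neighbors has shape (n, 16): indices of the nearest neighbors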

Performance Metrics

Auxiliary functions for high-performance benchmarking metrics

welford_update(carry, new_value, finite_threshold=1000000000000.0)[source]

Update running mean and variance using Welford’s algorithm, ignoring values beyond the given finite_threshold.

Parameters:
  • carry (tuple(int, float, float)) – (count, mean, M2) — intermediate stats.

  • new_value (float) – Incoming value to incorporate.

  • finite_threshold (float) – Max magnitude allowed for inclusion.

Returns:

(carry, None) – the updated carry and a dummy output for lax.scan.

Return type:

tuple

welford_finalize(agg)[source]

Finalize the computation of mean and variance from the aggregate statistics.

Parameters:

agg (tuple) – A tuple containing the aggregated statistics:

  • count: (int) The total count of valid (non-NaN) entries.

  • mean: (float) The computed mean of the dataset.

  • M2: (float) The computed sum of squares of differences from the mean.

Returns:

A tuple containing the final mean and standard deviation of the dataset.

Return type:

tuple

welford(data)[source]

Compute the mean and standard deviation of a dataset using Welford’s algorithm.

Parameters:

data (jax.numpy.ndarray) – An array of data points, potentially containing NaNs, which are ignored.

Returns:

A tuple containing the mean and standard deviation of the valid entries in the dataset.

Return type:

tuple
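For reference, a minimal self-contained version of the algorithm with lax.scan (using the population standard deviation; the library's handling of NaNs and of the finite threshold is omitted here):

    import jax
    import jax.numpy as jnp

    def welford_step(carry, x):
        count, mean, m2 = carry
        count = count + 1.0
        delta = x - mean
        mean = mean + delta / count        # running mean
        m2 = m2 + delta * (x - mean)       # running sum of squared deviations
        return (count, mean, m2), None

    def welford_sketch(data):
        (count, mean, m2), _ = jax.lax.scan(welford_step, (0.0, 0.0, 0.0), data)
        return mean, jnp.sqrt(m2 / count)

    mean, std = welford_sketch(jnp.arange(10.0))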

make_knn_graph(data, n_neighbors, batch_size=None)[source]

Compute the distances to nearest neighbors and their indices in the kNN graph of data.

Parameters:
  • data (numpy.ndarray) – High-dimensional data points.

  • n_neighbors (int) – Number of nearest neighbors to find for each point.

  • batch_size (int or None, optional) – Number of samples to process at once. If None, a suitable value will be automatically determined based on dataset size.

Returns:

Tuple containing:

  • distances – array of shape (n_samples, n_neighbors+1) with distances to nearest neighbors

  • indices – array of shape (n_samples, n_neighbors+1) with indices of nearest neighbors

The first column contains each point’s self-reference (distance 0.0 and its own index). The remaining columns contain the n_neighbors nearest neighbors in ascending order of distance.

Return type:

tuple(numpy.ndarray, numpy.ndarray)
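For example, with data as in the DiRe example above:

    distances, indices = make_knn_graph(data, n_neighbors=16)
    # column 0 is each point's self-reference (distance 0.0, own index);
    # columns 1..16 are the nearest neighbors in ascending order of distance
    nn_dist = distances[:, 1]  # distance to each point's nearest neighbor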

compute_stress(data, layout, n_neighbors, eps=1e-06)[source]

Compute the stress of an embedding based on the distances in the original high-dimensional space and the embedded space, using a ratio of distances.

Parameters:
  • data (numpy.ndarray) – High-dimensional data points.

  • layout (numpy.ndarray) – Embedded data points.

  • n_neighbors (int) – Number of nearest neighbors to consider for each point.

  • eps (float) – Parameter to prevent division by zero if the mean distortion is near zero, default 1e-6.

Returns:

The normalized stress value indicating the quality of the embedding.

Return type:

float

compute_neighbor_score(data, layout, n_neighbors)[source]

Computes the neighborhood preservation score between high-dimensional data and its corresponding low-dimensional layout.

The function evaluates how well the neighborhood relationships are preserved when data is projected from a high-dimensional space to a lower-dimensional space using the K-nearest neighbors approach. This involves comparing the nearest neighbors in the original space with those in the reduced space.

Parameters:
  • data (numpy.ndarray) – Array of shape (n_samples, data_dim) containing the original high-dimensional data.

  • layout (numpy.ndarray) – Array of shape (n_samples, embed_dim) containing the lower-dimensional embedding of the data.

  • n_neighbors (int) – The number of nearest neighbors to consider for each data point.

Returns:

A list containing two floats:

  • neighbor_mean: (float) The mean of the neighborhood preservation scores.

  • neighbor_std: (float) The standard deviation of the neighborhood preservation scores.

Return type:

list

compute_local_metrics(data, layout, n_neighbors, memory_efficient=None)[source]

Compute local metrics of the (data, layout) pair.

Parameters:
  • data (numpy.ndarray) – High-dimensional data points.

  • layout (numpy.ndarray) – Low-dimensional data points corresponding to the high-dimensional data.

  • n_neighbors (int) – Number of closest neighbors for the kNN graph.

  • memory_efficient (bool or None) – If True, use memory-efficient algorithms for large datasets; if None, determined automatically based on dataset size.

Returns:

A dictionary containing the computed scores of each type (stress, neighborhood preservation).

Return type:

dict

threshold_subsample(*arrays, threshold, rng_key)[source]

Subsample multiple arrays based on a specified threshold. The function generates random numbers and selects the samples where the random number is less than the threshold.

Parameters:
  • *arrays (tuple of numpy.ndarray) – The input data arrays to be subsampled. Each array must have the same number of samples (rows).

  • threshold (float) – Probability threshold for subsampling; only samples whose generated random number falls below this value are kept.

  • rng_key (jax.random.PRNGKey) – Random key used for generating random numbers, ensuring reproducibility.

Returns:

A tuple containing the subsampled arrays, in the same order as the input arrays.

Return type:

tuple
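The selection rule is simple; below is a sketch under the stated semantics (each row is kept when its uniform draw falls below the threshold, with the same mask applied to every array):

    import numpy as np
    import jax

    def threshold_subsample_sketch(rng_key, threshold, *arrays):
        n = arrays[0].shape[0]
        mask = np.asarray(jax.random.uniform(rng_key, (n,)) < threshold)
        return tuple(a[mask] for a in arrays)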

diagrams(data, layout, max_dim, subsample_threshold, rng_key)[source]

Generate persistence diagrams for high-dimensional and low-dimensional data up to a specified dimension, after subsampling both datasets based on a threshold. The subsampling is performed to reduce the dataset size and potentially highlight more relevant features when computing topological summaries.

Parameters:
  • data (numpy.ndarray) – High-dimensional data points.

  • layout (numpy.ndarray) – Low-dimensional data points corresponding to the high-dimensional data.

  • max_dim (int) – Maximum dimension of homology groups to compute.

  • subsample_threshold (float) – Threshold used for subsampling the data points.

  • rng_key (jax.random.PRNGKey) – Random key used for subsampling, ensuring reproducibility.

Returns:

A dictionary containing two keys, ‘data’ and ‘layout’, each associated with arrays of persistence diagrams for the respective high-dimensional and low-dimensional datasets.

Return type:

dict

betti_curve(diagram, n_steps=100)[source]

Compute the Betti curve from a persistence diagram, which is a function of the number of features that persist at different filtration levels. This curve provides a summary of topological features across scales.

Parameters:
  • diagram (list of tuples) – A persistence diagram represented as a list of (birth, death) tuples indicating the range over which each topological feature persists.

  • n_steps (int, optional) – The number of steps or points in the filtration range at which to evaluate the Betti number.

Returns:

A tuple of two numpy arrays:

  • The first array contains the evenly spaced filtration values.

  • The second array contains the Betti numbers at each filtration value.

Return type:

tuple
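A direct way to compute such a curve from a diagram (a sketch; infinite death times are clipped to the largest finite death):

    import numpy as np

    def betti_curve_sketch(diagram, n_steps=100):
        births = np.array([b for b, d in diagram])
        deaths = np.array([d for b, d in diagram])
        finite = deaths[np.isfinite(deaths)]
        upper = finite.max() if finite.size else births.max() + 1.0
        ts = np.linspace(births.min(), upper, n_steps)
        # Betti number at t = number of intervals with birth <= t < death
        betti = np.array([np.sum((births <= t) & (t < deaths)) for t in ts])
        return ts, betti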

compute_dtw(axis_x_hd, axis_y_hd, axis_x_ld, axis_y_ld, norm_factor=1.0)[source]

Compute normalized Dynamic Time Warp (DTW) distance (using Euclidean metric) between two Betti curves represented as time series with time dimension x and values y.

Parameters:
  • axis_x_hd (numpy.ndarray) – Time axis of the high-dimensional Betti curve.

  • axis_y_hd (numpy.ndarray) – Values of the high-dimensional Betti curve.

  • axis_x_ld (numpy.ndarray) – Time axis of the low-dimensional Betti curve.

  • axis_y_ld (numpy.ndarray) – Values of the low-dimensional Betti curve.

  • norm_factor (float) – Normalization factor, default 1.0.

Returns:

Normalized DTW distance between the two Betti curves.

Return type:

float

compute_twed(axis_x_hd, axis_y_hd, axis_x_ld, axis_y_ld, norm_factor=1.0)[source]

Compute normalized Time Warp Edit Distance (TWED) distance using Euclidean metric between two Betti curves represented as time series with time dimension x and values y.

Parameters:
  • axis_x_hd (numpy.ndarray) – Time axis of the high-dimensional Betti curve.

  • axis_y_hd (numpy.ndarray) – Values of the high-dimensional Betti curve.

  • axis_x_ld (numpy.ndarray) – Time axis of the low-dimensional Betti curve.

  • axis_y_ld (numpy.ndarray) – Values of the low-dimensional Betti curve.

  • norm_factor (float) – Normalization factor, default 1.0.

Returns:

Normalized TWED distance between the two Betti curves.

Return type:

float

compute_emd(axis_x_hd, axis_y_hd, axis_x_ld, axis_y_ld, adjust_mass=False, norm_factor=1.0)[source]

Compute normalized Earth Mover Distance (EMD) distance (using Euclidean metric) between two Betti curves represented as time series with time dimension x and values y.

Parameters:
  • axis_x_hd (numpy.ndarray) – Time axis of the high-dimensional Betti curve.

  • axis_y_hd (numpy.ndarray) – Values of the high-dimensional Betti curve.

  • axis_x_ld (numpy.ndarray) – Time axis of the low-dimensional Betti curve.

  • axis_y_ld (numpy.ndarray) – Values of the low-dimensional Betti curve.

  • adjust_mass (bool) – Whether to adjust mass (by default, EMD is computed for unit-mass curves); default False.

  • norm_factor (float) – Normalization factor, default 1.0.

Returns:

Normalized EMD distance between the two Betti curves.

Return type:

float

compute_wasserstein(diag_hd, diag_ld, norm_factor=1.0)[source]

Compute normalized Wasserstein distance between two persistence diagrams (usually one of high-dimensional data and one of low-dimensional data).

Parameters:
  • diag_hd (list of tuples) – Persistence diagram for the high-dimensional data.

  • diag_ld (list of tuples) – Persistence diagram for the low-dimensional data.

  • norm_factor (float) – Normalization factor, default 1.0.

Returns:

Normalized Wasserstein distance between the persistence diagrams.

Return type:

float

compute_bottleneck(diag_hd, diag_ld, norm_factor=1.0)[source]

Compute normalized bottleneck distance between two persistence diagrams (usually one of high-dimensional data and one of low-dimensional data).

Parameters:
  • diag_hd (list of tuples) – Persistence diagram for the high-dimensional data.

  • diag_ld (list of tuples) – Persistence diagram for the low-dimensional data.

  • norm_factor (float) – Normalization factor, default 1.0.

Returns:

Normalized bottleneck distance between the persistence diagrams.

Return type:

float

compute_global_metrics(data, layout, dimension, subsample_threshold, rng_key, n_steps=100, metrics_only=True)[source]

Compute and compare persistence metrics between high-dimensional and low-dimensional data representations. The function calculates the Dynamic Time Warp (DTW), Time Warp Edit Distance (TWED), and Earth Mover Distance based on Betti curves derived from persistence diagrams, as well as the Wasserstein and bottleneck distances based on the persistence diagrams themselves. This evaluation helps quantify the topological fidelity of the dimensionality reduction.

Parameters:
  • data (numpy.ndarray) – High-dimensional data points.

  • layout (numpy.ndarray) – Low-dimensional data points corresponding to the high-dimensional data.

  • dimension (int) – The maximum dimension for which to compute persistence diagrams.

  • subsample_threshold (float) – Threshold used for subsampling the data.

  • rng_key (jax.random.PRNGKey) – Random key used for subsampling, ensuring reproducibility.

  • n_steps (int, optional) – The number of steps or points in the filtration range for computing Betti curves.

  • metrics_only (bool) – If True, return metrics only; otherwise diagrams and Betti curves are also returned; default True.

Returns:

  • If metrics_only is True – dict: A dictionary containing one item, ‘metrics’, which is a dictionary of lists of computed distances for each metric (DTW, TWED, EMD, Wasserstein, and bottleneck). Each list is populated according to the dimensions in which the distances were computed.

  • If metrics_only is False – dict: A dictionary containing three items: ‘metrics’, a dictionary of metrics as described above; ‘diags’, a dictionary of diagrams for the initial data and for the layout; and ‘bettis’, a dictionary of Betti curves for the initial data and for the layout. Each is a dictionary of lists, populated according to the dimensions in which the distances, diagrams, or curves were computed.

compute_svm_accuracy(X, y, test_size=0.3, reg_param=1.0, max_iter=100, random_state=42)[source]

Compute linear SVM classifier accuracy for given labelled data X with labels y.

Parameters:
  • X (numpy.ndarray) – Data.

  • y (numpy.ndarray) – Data labels.

  • test_size (float) – Test size (between 0.0 and 1.0) for the train / test split, default 0.3.

  • reg_param (float) – Regularization parameter for the SVM, default 1.0.

  • max_iter (int) – Maximum number of iterations for SVM training, default 100.

  • random_state (int) – Random state for reproducibility, default 42.

Returns:

Accuracy of the linear SVM model on the test set.

Return type:

float

compute_svm_score(data, layout, labels, subsample_threshold, rng_key, **kwargs)[source]

Compute SVM score (context preservation measure) by comparing linear SVM classifier accuracies on the high-dimensional data and on the low-dimensional embedding.

Parameters:
  • data (numpy.ndarray) – High-dimensional data.

  • layout (numpy.ndarray) – Low-dimensional embedding.

  • labels (numpy.ndarray) – Data labels.

  • subsample_threshold (float) – Threshold used for subsampling the data.

  • rng_key (jax.random.PRNGKey) – Random key used for subsampling, ensuring reproducibility.

  • kwargs – Other keyword arguments used by the accuracy computation above.

Returns:

SVM context preservation score.

Return type:

float

compute_knn_accuracy(X, y, n_neighbors=16, test_size=0.3, random_state=42)[source]

Compute kNN classifier accuracy for given labelled data X with labels y.

Parameters:
  • X (numpy.ndarray) – Data.

  • y (numpy.ndarray) – Data labels.

  • n_neighbors (int) – Number of neighbors for kNN classification, default 16.

  • test_size (float) – Test size (between 0.0 and 1.0) for the train / test split, default 0.3.

  • random_state (int) – Random state for reproducibility, default 42.

Returns:

Accuracy of the kNN model on the test set.

Return type:

float

compute_knn_score(data, layout, labels, n_neighbors=16, **kwargs)[source]

Compute kNN score (context preservation measure) by comparing kNN classifier accuracies on the high-dimensional data and on the low-dimensional embedding.

Parameters:
  • data (numpy.ndarray) – High-dimensional data.

  • layout (numpy.ndarray) – Low-dimensional embedding.

  • labels (numpy.ndarray) – Data labels.

  • n_neighbors (int) – Number of nearest neighbors for the kNN classifier, default 16.

  • kwargs – Other keyword arguments used by the accuracy computation above.

Returns:

kNN context preservation score.

Return type:

float

compute_quality_measures(data, layout, n_neighbors=None)[source]

Compute quality measures for assessing the quality of dimensionality reduction.

This function calculates various metrics that evaluate how well the low-dimensional representation preserves important properties of the high-dimensional data.

Parameters:
  • data (numpy.ndarray) – High-dimensional data points.

  • layout (numpy.ndarray) – Low-dimensional embedding of the data.

  • n_neighbors (int or None, optional) – Number of neighbors used for the neighbor-based measures; if None, determined automatically.

Returns:

Dictionary of quality measures including:

  • trustworthiness – whether points that are close in the embedding are also close in the original space

  • continuity – whether points that are close in the original space are also close in the embedding

  • shepard_correlation – correlation between pairwise distances in the original and embedded spaces

Return type:

dict

compute_context_measures(data, layout, labels, subsample_threshold, n_neighbors, rng_key, **kwargs)[source]

Compute measures of how well the embedding preserves the context of the data.

Parameters:
  • data (numpy.ndarray) – High-dimensional data points.

  • layout (numpy.ndarray) – Low-dimensional embedding of the data.

  • labels (numpy.ndarray) – Data labels needed for context preservation analysis.

  • subsample_threshold (float) – Threshold used for subsampling the data.

  • n_neighbors (int) – Number of neighbors for the kNN graph.

  • rng_key (jax.random.PRNGKey) – Random key for reproducible subsampling.

  • **kwargs – Additional keyword arguments for the scoring functions.

Returns:

Dictionary of context preservation measures, including SVM and kNN classification performance comparisons.

Return type:

dict