# dire.py
"""
Provides the main class for dimensionality reduction.
The DiRe (Dimensionality Reduction) class implements a modern approach to
dimensionality reduction, leveraging JAX for efficient computation. It uses
force-directed layout techniques combined with k-nearest neighbor graph
construction to generate meaningful low-dimensional embeddings of
high-dimensional data.
"""
#
# Imports
#
import sys
import os
import gc
import functools
# JAX-related imports
import jax
from jax import jit, lax, vmap, random, device_put, device_get
import jax.numpy as jnp
# Scientific and numerical libraries
import numpy as np
from scipy.sparse.csgraph import laplacian
from scipy.optimize import curve_fit
from scipy.sparse import csr_matrix
from scipy.sparse.linalg import eigsh
from sklearn.decomposition import PCA, KernelPCA
# Data structures and visualization
import pandas as pd
import plotly.express as px
# Utilities
from tqdm import tqdm
from loguru import logger
# Nearest neighbor search
from .hpindex import HPIndex
#
# Double precision support
#
jax.config.update("jax_enable_x64", True)
#
# Main class for Dimensionality Reduction
#
[docs]
class DiRe:
"""
Dimension Reduction (DiRe) is a class designed to reduce the dimensionality of high-dimensional
data using various embedding techniques and optimization algorithms. It supports embedding
initialization methods such as random and spectral embeddings and utilizes k-nearest neighbors
for manifold learning.
Parameters
----------
dimension: (int) Embedding dimension, default 2.
n_neighbors: (int) Number of nearest neighbors to consider for each point, default 16.
init_embedding_type: (str)
Method to initialize the embedding; choices are:
- 'random' for random projection based on the Johnson-Lindenstrauss lemma;
- 'spectral' for spectral embedding (with sim_kernel as similarity kernel);
- 'pca' for PCA embedding (classical, no kernel).
By default, 'random'.
sim_kernel: (callable)
A similarity kernel function that transforms a distance metric to a similarity score.
The function should have the form `lambda distance: float -> similarity: float`; default `None`.
max_iter_layout: (int)
Maximum number of iterations to run the layout optimization, default 128.
min_dist: (float)
Minimum distance scale for distribution kernel, default 0.01.
spread: (float)
Spread of the distribution kernel, default 1.0.
cutoff: (float)
Cutoff for clipping forces during layout optimization, default 42.0.
n_sample_dirs: (int)
Number of directions to sample in random sampling, default 8.
sample_size: (int)
Number of samples per direction in random sampling, default 16.
neg_ratio: (int)
Ratio of negative to positive samples in random sampling, default 8.
my_logger: (logger.Logger or `None`)
Custom logger for logging events; if None, a default logger is created, default `None`.
verbose: (bool)
Flag to enable verbose output, default `True`.
Attributes
----------
dimension: int
Target dimensionality of the output space.
n_neighbors: int
Number of neighbors to consider in the k-nearest neighbors graph.
init_embedding_type: str
Chosen method for initial embedding.
sim_kernel: callable
Similarity kernel function to be used if 'init_embedding_type' is 'spectral', by default `None`.
pca_kernel: callable
Kernel function to be used if 'init_embedding_type' is 'pca', by default `None`.
max_iter_layout: int
Maximum iterations for optimizing the layout.
min_dist: float
Minimum distance for repulsion used in the distribution kernel.
spread: float
Spread between the data points used in the distribution kernel.
cutoff: float
Maximum cutoff for forces during optimization.
n_sample_dirs: int
Number of random directions sampled.
sample_size: int or 'auto'
Number of samples per random direction, unless chosen automatically with 'auto'.
neg_ratio: int
Ratio of negative to positive samples in the sampling process.
logger: logger.Logger or `None`
Logger used for logging informational and warning messages.
verbose: bool
Logger output flag (True = output logger messages, False = flush to null)
memm: dictionary or `None`
Memory manager: a dictionary with the batch / memory tile size for different
hardware architectures. Accepts 'tpu', 'gpu' and 'other' as keys. Values must
be positive integers.
mpa: bool
Mixed Precision Arithmetic flag (True = use MPA, False = always use float64)
Methods
-------
fit_transform(data)
A convenience method that fits the model and then transforms the data.
The separate `fit` and `transform` methods can only be used one after
another because dimension reduction is applied to the dataset as a whole.
visualize(labels=None, point_size=2)
Visualizes the transformed data, optionally using labels to color the points.
"""
[docs]
def __init__(self,
dimension=2,
n_neighbors=16,
init_embedding_type='random',
sim_kernel=None,
pca_kernel=None,
max_iter_layout=128,
min_dist=1e-2,
spread=1.0,
cutoff=42.0,
n_sample_dirs=8,
sample_size=16,
neg_ratio=8,
my_logger=None,
verbose=True,
memm=None,
mpa=True):
"""
Class constructor
"""
#
self.dimension = dimension
""" Embedding dimension """
self.n_neighbors = n_neighbors
""" Number of neighbors for kNN computations"""
self.init_embedding_type = init_embedding_type
""" Type of the initial embedding (PCA, random, spectral) """
self.sim_kernel = sim_kernel
""" Similarity kernel """
self.pca_kernel = pca_kernel
""" PCA kernel """
self.max_iter_layout = max_iter_layout
""" Max iterations for the force layout """
self.min_dist = min_dist
""" Min distance between points in layout """
self.spread = spread
""" Layout spread """
self.cutoff = cutoff
""" Cutoff for layout displacement """
self.n_sample_dirs = n_sample_dirs
""" Number of sampling directions for layout"""
self.sample_size = sample_size
""" Sample size for attraction """
self.neg_ratio = neg_ratio
""" Ratio for repulsion sample size """
self._init_embedding = None
""" Initial embedding """
self._layout = None
""" Layout output """
self._a = None
""" Probability kernel parameter """
self._b = None
""" Probability kernel parameter """
self._data = None
""" Higher-dimensional data """
self._n_samples = None
""" Number of data points """
self._data_dim = None
""" Dimension of data """
self._distances_np = None
self._distances_jax = None
""" Distances in the kNN graph """
self._indices_np = None
self._indices_jax = None
""" Neighbor indices in the kNN graph """
self._nearest_neighbor_distances = None
""" Neighbor indices in the kNN graph, excluding the point itself """
self._row_idx = None
""" Row indices for nearest neighbors """
self._col_idx = None
""" Column indices for nearest neighbors """
self._adjacency = None
""" kNN adjacency matrix """
#
if my_logger is None:
logger.remove()
sink = sys.stdout if verbose else open(os.devnull, 'w', encoding='utf-8')
logger.add(sink, level="INFO")
self.logger = logger
""" System logger """
else:
self.logger = my_logger
# Memory manager to be adjusted for each particular type of hardware
# Below are some minimalist settings that may give less than satisfactory performance
self.memm = {'gpu': 16384, 'tpu': 8192, 'other': 8192} if memm is None else memm
# Using Mixed Precision Arithmetic flag (True = MPA, False = no MPA)
self.mpa = mpa
#
# Fitting the distribution kernel with given min_dist and spread
#
[docs]
def find_ab_params(self, min_dist=0.01, spread=1.0):
"""
Rational function approximation to the probabilistic t-kernel
"""
#
self.logger.info('find_ab_params ...')
def curve(x, a, b):
return 1.0 / (1.0 + a * x ** (2 * b))
#
xv = np.linspace(0, spread * 3, 300)
yv = np.zeros(xv.shape)
yv[xv < min_dist] = 1.0
yv[xv >= min_dist] = np.exp(-(xv[xv >= min_dist] - min_dist) / spread)
params, _ = curve_fit(curve, xv, yv)
#
self.logger.info(f'a = {params[0]}, b = {params[1]}')
self.logger.info('find_ab_params done ...')
#
return params[0], params[1]
#
# Fitting on data
#
[docs]
def fit(self, data):
"""
Fit the model to data: create the kNN graph and fit the probability kernel to force layout parameters.
Parameters
----------
data: (numpy.ndarray)
High-dimensional data to fit the model. Shape (n_samples, n_features).
Returns
-------
self: The DiRe instance fitted to data.
"""
#
self.logger.info('fit ...')
#
self._data = data
self._n_samples = self._data.shape[0]
self._data_dim = self._data.shape[1]
self.logger.info(f'Dimension {self._data_dim}, number of samples {self._n_samples}')
self.make_knn_adjacency()
self._a, self._b = self.find_ab_params(self.min_dist, self.spread)
#
self.logger.info('fit done ...')
#
return self
#
# Transform fitted data into lower-dimensional representation
#
#
# Computing the kNN adjacency matrix (sparse)
#
[docs]
def make_knn_adjacency(self, batch_size=None):
"""
Internal routine building the adjacency matrix for the kNN graph.
This method computes the k-nearest neighbors for each point in the dataset
and constructs a sparse adjacency matrix representing the kNN graph.
It attempts to use GPU acceleration if available, with a fallback to CPU.
For large datasets, it uses batching to limit memory usage.
Parameters
----------
batch_size : int or None, optional
Number of samples to process at once. If None, a suitable value
will be automatically determined based on dataset size.
The method sets the following instance attributes:
- distances: Distances to the k nearest neighbors (including self)
- indices: Indices of the k nearest neighbors (including self)
- nearest_neighbor_distances: Distances to the nearest neighbors (excluding self)
- row_idx, col_idx: Indices for constructing the sparse adjacency matrix
- adjacency: Sparse adjacency matrix of the kNN graph
"""
self.logger.info('make_knn_adjacency ...')
# Ensure data is in the right format for HPIndex
self._data = np.ascontiguousarray(self._data.astype(np.float64))
n_neighbors = self.n_neighbors + 1 # Including the point itself
# Determine appropriate batch size for memory efficiency
if batch_size is None:
# Process in chunks to reduce peak memory usage
if jax.devices()[0].platform == 'tpu':
batch_size = min(self.memm['tpu'], self._n_samples)
elif jax.devices()[0].platform == 'gpu':
batch_size = min(self.memm['gpu'], self._n_samples)
else:
batch_size = min(self.memm['other'], self._n_samples)
self.logger.info(f'Using batch size: {batch_size}')
self.logger.debug(f"[KNN] Using precision: {'float32' if self.mpa else 'float64'}")
if self.mpa:
self._indices_jax, self._distances_jax = HPIndex.knn_tiled(
self._data, self._data, n_neighbors, batch_size, batch_size, dtype=jnp.float32)
else:
self._indices_jax, self._distances_jax = HPIndex.knn_tiled(
self._data, self._data, n_neighbors, batch_size, batch_size, dtype=jnp.float64)
# Wait until ready
self._indices_jax.block_until_ready()
self._distances_jax.block_until_ready()
# Store results in numpy
self._indices_np = device_get(self._indices_jax).astype(np.int64)
self._distances_np = device_get(self._distances_jax).astype(np.float64)
# Extract nearest neighbor distances (excluding self)
self._nearest_neighbor_distances = self._distances_np[:, 1:]
# Create indices for sparse matrix construction
self._row_idx = np.repeat(np.arange(self._n_samples), n_neighbors)
self._col_idx = self._indices_np.ravel()
# Create sparse adjacency matrix (memory efficient)
data_values = self._distances_np.ravel()
self._adjacency = csr_matrix(
(data_values, (self._row_idx, self._col_idx)),
shape=(self._n_samples, self._n_samples)
)
# Clean up resources
gc.collect()
self.logger.info('make_knn_adjacency done ...')
#
# Initialize embedding using different techniques
#
[docs]
def do_pca_embedding(self):
"""
Initialize embedding using Principal Component Analysis (PCA).
This method creates an initial embedding of the data using PCA, which finds
a linear projection of the high-dimensional data into a lower-dimensional space
that maximizes the variance. If a kernel is specified, Kernel PCA is used instead,
which can capture nonlinear relationships.
Sets the init_embedding attribute with the PCA projection of the data.
"""
self.logger.info('do_pca_embedding ...')
if self.pca_kernel is not None:
# Use Kernel PCA for nonlinear dimensionality reduction
self.logger.info('Using kernelized PCA embedding...')
pca = KernelPCA(
n_components=self.dimension,
kernel=self.pca_kernel
)
self._init_embedding = pca.fit_transform(self._data)
else:
# Use standard PCA for linear dimensionality reduction
self.logger.info('Using standard PCA embedding...')
pca = PCA(n_components=self.dimension)
self._init_embedding = pca.fit_transform(self._data)
self.logger.info('do_pca_embedding done ...')
[docs]
def do_spectral_embedding(self):
"""
Initialize embedding using Spectral Embedding.
This method creates an initial embedding of the data using spectral embedding,
which is based on the eigenvectors of the graph Laplacian. It relies on the
kNN graph structure to find a lower-dimensional representation that preserves
local relationships.
If a similarity kernel is specified, it is applied to transform the distances
in the adjacency matrix before computing the Laplacian.
Sets the init_embedding attribute with the spectral embedding of the data.
"""
self.logger.info('do_spectral_embedding ...')
# Apply similarity kernel if provided
if self.sim_kernel is not None:
self.logger.info('Applying similarity kernel to adjacency matrix...')
# Transform distances using the similarity kernel
data_values = self.sim_kernel(self._adjacency.data)
# Create a new adjacency matrix with transformed values
adj_mat = csr_matrix(
(data_values, (self._row_idx, self._col_idx)),
shape=(self._n_samples, self._n_samples)
)
else:
adj_mat = self._adjacency
# Make the adjacency matrix symmetric by adding it to its transpose
symmetric_adj = adj_mat + adj_mat.T
# Compute the normalized Laplacian
lap = laplacian(symmetric_adj, normed=True)
# Find the k smallest eigenvectors (k = dimension + 1)
k = self.dimension + 1
_, eigenvectors = eigsh(lap, k, which='SM')
# Skip the first eigenvector (corresponds to constant function)
self._init_embedding = eigenvectors[:, 1:k]
self.logger.info('do_spectral_embedding done ...')
[docs]
def do_random_embedding(self):
"""
Initialize embedding using Random Projection.
This method creates an initial embedding of the data using random projection,
which is a simple and computationally efficient technique for dimensionality
reduction. It projects the data onto a randomly generated basis, providing
a good starting point for further optimization.
Random projection is supported by the Johnson-Lindenstrauss lemma, which
guarantees that the distances between points are approximately preserved
under certain conditions.
Sets the init_embedding attribute with the random projection of the data.
"""
self.logger.info('do_random_embedding ...')
# Create a random projection matrix
key = random.PRNGKey(13) # Fixed seed for reproducibility
rand_basis = random.normal(key, (self.dimension, self._data_dim))
# Move data and projection matrix to device memory
data_matrix = device_put(self._data)
rand_basis = device_put(rand_basis)
# Project data onto random basis
self._init_embedding = data_matrix @ rand_basis.T
self.logger.info('do_random_embedding done ...')
#
# Efficient sampling for force-directed layout
#
[docs]
def do_rand_sampling(self, key, arr, n_samples, n_dirs, neg_ratio):
"""
Sample points for force calculation using random projections.
This method implements an efficient sampling strategy to identify points
for applying attractive and repulsive forces during layout optimization.
It uses random projections to quickly identify nearby points in different
directions, and also adds random negative samples for repulsion.
Parameters
----------
key : jax.random.PRNGKey
Random number generator key
arr : jax.numpy.ndarray
Array of current point positions
n_samples : int
Number of samples to take in each direction
n_dirs : int
Number of random directions to sample
neg_ratio : int
Ratio of negative samples to positive samples
Returns
-------
jax.numpy.ndarray
Array of sampled indices for force calculations
Notes
-----
The sampling strategy works as follows:
1. Generate n_dirs random unit vectors
2. Project the points onto each vector
3. For each point, take the n_samples closest points in each direction
4. Add random negative samples for repulsion
5. Combine all sampled indices
This approach is more efficient than a full nearest neighbor search
while still capturing the important local relationships.
"""
self.logger.info('do_rand_sampling ...')
sampled_indices_list = []
arr_len = len(arr)
# Get random unit vectors for projections
key, subkey = random.split(key)
direction_vectors = rand_directions(subkey, self.dimension, n_dirs)
# For each direction, sample points based on projections
for vec in direction_vectors:
# Project points onto the direction vector
arr_proj = vec @ arr.T
# Sort indices by projection values
indices_sort = jnp.argsort(arr_proj)
# For each point, take n_samples points around it in sorted order
vmap_get_slice = vmap(get_slice, in_axes=(None, None, 0))
indices = vmap_get_slice(indices_sort, n_samples, jnp.arange(arr_len))
# Reorder indices back to original ordering
indices = indices[indices_sort]
# Add to list of sampled indices
sampled_indices_list.append(indices)
# Generate random negative samples for repulsion
n_neg_samples = int(neg_ratio * n_samples)
key, subkey = random.split(key)
neg_indices = random.randint(subkey, (arr_len, n_neg_samples), 0, arr_len)
sampled_indices_list.append(neg_indices)
# Combine all sampled indices
sampled_indices = jnp.concatenate(sampled_indices_list, axis=-1)
self.logger.info('do_rand_sampling done ...')
return sampled_indices
#
# Create layout using force-directed optimization
#
[docs]
def do_layout(self, large_dataset_mode=None, force_cpu=False):
"""
Optimize the layout using force-directed placement with JAX kernels.
This method takes the initial embedding and iteratively refines it using
attractive and repulsive forces to create a meaningful low-dimensional
representation of the high-dimensional data. The algorithm applies:
1. Attraction forces between points that are neighbors in the high-dimensional space
2. Repulsion forces between randomly sampled points in the low-dimensional space
3. Gradual cooling (decreasing force impact) as iterations progress
The final layout is normalized to have zero mean and unit standard deviation.
Parameters
----------
large_dataset_mode : bool or None, optional
If True, use memory-efficient techniques for large datasets.
If None, automatically determine based on dataset size.
force_cpu : bool, optional
If True, force computations on CPU instead of GPU, which can
be helpful for large datasets that exceed GPU memory.
"""
self.logger.info('do_layout ...')
# Setup parameters
cutoff = jnp.array([self.cutoff])
num_iterations = self.max_iter_layout
# Handle automatic batch size calculation if needed
sample_size = self.sample_size
if sample_size == 'auto':
# Scale batch size based on dataset size and neighborhood size
sample_size = int(self.n_neighbors * np.log(self._n_samples))
# Ensure batch size is reasonable (not too small or large)
sample_size = max(min(512, sample_size), 32)
# Determine if we should use memory-efficient mode for large datasets
if large_dataset_mode is None:
large_dataset_mode = ((self._n_samples > 65536) or
(jax.devices()[0].platform == 'tpu'))
# Other parameters
n_dirs = self.n_sample_dirs
neg_ratio = self.neg_ratio
# Debug initial embedding precision
self.logger.debug(f"[LAYOUT] Initial embedding precision: {self._init_embedding.dtype}")
# we shall use force_cpu only as a flag passed to the routine
# force_cpu = force_cpu or large_dataset_mode and (jax.devices()[0].platform == 'tpu')
# Initialize and normalize positions
if force_cpu:
self.logger.info("Forcing computations on CPU")
cpu_device = jax.devices('cpu')[0]
init_pos_jax = device_put(self._init_embedding, device=cpu_device)
neighbor_indices_jax = device_put(self._indices_np, device=cpu_device)
else:
init_pos_jax = device_put(self._init_embedding)
neighbor_indices_jax = device_put(self._indices_jax)
init_pos_jax -= init_pos_jax.mean(axis=0) # Center positions
init_pos_jax /= init_pos_jax.std(axis=0) # Normalize variance
# Set random seed for reproducibility
key = random.PRNGKey(42)
# Optimization loop
for iter_id in tqdm(range(num_iterations)):
logger.debug(f'Iteration {iter_id + 1}')
# Sample random points for repulsion
sample_indices_jax = self.do_rand_sampling(
key,
init_pos_jax,
sample_size,
n_dirs,
neg_ratio
)
if force_cpu:
cpu_device = jax.devices('cpu')[0]
sample_indices_jax = device_put(sample_indices_jax, device=cpu_device)
else:
sample_indices_jax = device_put(sample_indices_jax)
# Split computation for memory efficiency if needed
if large_dataset_mode:
# Process in chunks to reduce peak memory usage
if jax.devices()[0].platform == 'tpu':
chunk_size = min(self.memm['tpu'], self._n_samples)
elif jax.devices()[0].platform == 'gpu':
chunk_size = min(self.memm['gpu'], self._n_samples)
else:
chunk_size = min(self.memm['other'], self._n_samples)
# this is actually inefficient, but let's postpone
all_forces = []
self.logger.info(f"Using memory tiling with tile size: {chunk_size}")
for chunk_start in range(0, self._n_samples, chunk_size):
chunk_end = min(chunk_start + chunk_size, self._n_samples)
chunk_indices = jnp.arange(chunk_start, chunk_end)
# Process this chunk using our kernelized function
chunk_force = self._compute_forces(
init_pos_jax,
chunk_indices,
neighbor_indices_jax[chunk_indices],
sample_indices_jax[chunk_indices],
alpha=1.0 - iter_id / num_iterations
)
all_forces.append(chunk_force)
# Explicitly clean up to reduce memory pressure
gc.collect()
# Combine results from all chunks
net_force = jnp.concatenate(all_forces, axis=0)
else:
# Process all points at once for smaller datasets
net_force = self._compute_forces(
init_pos_jax,
jnp.arange(self._n_samples),
neighbor_indices_jax,
sample_indices_jax,
alpha=1.0 - iter_id / num_iterations
)
# Clip forces to prevent extreme movements
net_force = jnp.clip(net_force, -cutoff, cutoff)
# Update positions
init_pos_jax += net_force
# Ensure we're not accumulating unnecessary computation graphs in JAX
init_pos_jax.block_until_ready()
# Clean up resources
gc.collect()
# Normalize final layout
init_pos_jax -= init_pos_jax.mean(axis=0)
init_pos_jax /= init_pos_jax.std(axis=0)
# Store final layout
self._layout = np.asarray(init_pos_jax)
# Clear any cached values to free memory
gc.collect()
self.logger.info('do_layout done ...')
# Modified _compute_forces method to use the kernel
def _compute_forces(self, positions, chunk_indices, neighbor_indices, sample_indices, alpha=1.0):
"""
Compute attractive and repulsive forces for points using JAX kernels.
This method uses JAX-optimized kernels to efficiently compute forces
between points during layout optimization.
Parameters
----------
positions : jax.numpy.ndarray
Current point positions
chunk_indices : jax.numpy.ndarray
Current batch indices
neighbor_indices : jax.numpy.ndarray
Indices of neighbors for attractive forces
sample_indices : jax.numpy.ndarray
Indices of points for repulsive forces
alpha : float
Cooling factor that scales force magnitude
Returns
-------
jax.numpy.ndarray
Net force vectors for each point
"""
if self.mpa:
positions = positions.astype(jnp.float32)
else:
positions = positions.astype(jnp.float64)
self.logger.debug(f"[FORCE] Computing forces on device: {positions.device}")
self.logger.debug(f"[FORCE] Using precision: {positions.dtype}")
# Call the JAX-optimized kernel
return compute_forces_kernel(
positions,
chunk_indices,
neighbor_indices,
sample_indices,
alpha,
self._a,
self._b,
)
#
# Visualize the layout
#
[docs]
def visualize(self, labels=None, point_size=2, title=None, colormap=None, width=800, height=600, opacity=0.7):
"""
Generate an interactive visualization of the data in the transformed space.
This method creates a scatter plot visualization of the embedded data, supporting
both 2D and 3D visualizations depending on the specified dimension. Points can be
colored by provided labels for clearer visualization of clusters or categories.
Parameters
----------
labels : numpy.ndarray or None, optional
Labels for each data point to color the points in the visualization.
If None, all points will have the same color. Default is None.
point_size : int or float, optional
Size of points in the scatter plot. Default is 2.
title : str or None, optional
Title for the visualization. If None, a default title will be used. Default is None.
colormap : str or None, optional
Name of the colormap to use for labels (e.g., 'viridis', 'plasma').
If None, the default Plotly colormap will be used. Default is None.
width : int, optional
Width of the figure in pixels. Default is 800.
height : int, optional
Height of the figure in pixels. Default is 600.
opacity : float, optional
Opacity of the points (0.0 to 1.0). Default is 0.7.
Returns
-------
plotly.graph_objs._figure.Figure or None
A Plotly figure object if the visualization is successful;
None if no layout is available or dimension > 3.
Notes
-----
For 3D visualizations, you can rotate, zoom, and pan the plot interactively.
For both 2D and 3D, hover over points to see their coordinates and labels.
"""
# Check if layout is available
if self._layout is None:
self.logger.warning('visualize ERROR: no layout available')
return None
# Set default title if not provided
if title is None:
title = f"{self.init_embedding_type.capitalize()} Initialized {self.dimension}D Embedding"
# Common visualization parameters
vis_params = {
'color': 'label' if labels is not None else None,
'color_continuous_scale': colormap,
'opacity': opacity,
'title': title,
'hover_data': ['label'] if labels is not None else None,
}
# Create 2D visualization
if self.dimension == 2:
self.logger.info('visualize: 2D ...')
# Create dataframe for plotting
datadf = pd.DataFrame(self._layout, columns=['x', 'y'])
# Add labels if provided
if labels is not None:
datadf['label'] = labels
# Create scatter plot
fig = px.scatter(
datadf,
x='x',
y='y',
**vis_params
)
# Update layout
fig.update_layout(
width=width,
height=height,
xaxis_title='x',
yaxis_title='y',
)
# Create 3D visualization
elif self.dimension == 3:
self.logger.info('visualize: 3D ...')
# Create dataframe for plotting
datadf = pd.DataFrame(self._layout, columns=['x', 'y', 'z'])
# Add labels if provided
if labels is not None:
datadf['label'] = labels
# Create 3D scatter plot
fig = px.scatter_3d(
datadf,
x='x',
y='y',
z='z',
**vis_params
)
# Update layout
fig.update_layout(
width=width,
height=height,
scene={
"xaxis_title": 'x',
"yaxis_title": 'y',
"zaxis_title": 'z',
}
)
# Return None for higher dimensions
else:
self.logger.warning('visualize ERROR: dimension > 3')
return None
# Update marker properties
fig.update_traces(marker={"size": point_size})
return fig
#
# Kernel for force-directed layout
#
[docs]
@functools.partial(jit, static_argnums=(5, 6))
def compute_forces_kernel(positions, chunk_indices, neighbor_indices, sample_indices, alpha, a, b):
"""
JAX-optimized kernel for computing attractive and repulsive forces.
Parameters
----------
positions : jax.numpy.ndarray
Current point positions
chunk_indices : jax.numpy.ndarray
Current batch indices
neighbor_indices : jax.numpy.ndarray
Indices of neighbors for attractive forces
sample_indices : jax.numpy.ndarray
Indices of points for repulsive forces
alpha : float
Cooling factor that scales force magnitude
a : float
Attraction parameter
b : float
Repulsion parameter
Returns
-------
jax.numpy.ndarray
Net force vectors for each point
"""
# ===== Attraction Forces =====
def compute_attraction(chunk_idx, neighbors_idx):
# Get positions of current point and its neighbors
point_pos = positions[chunk_idx]
neighbor_pos = positions[neighbors_idx]
# Compute position differences and distances
diff = neighbor_pos - point_pos
dist = jnp.linalg.norm(diff, axis=1, keepdims=True)
# Avoid division by zero
mask = dist > 0
direction = jnp.where(mask, diff / dist, 0.0)
# Compute attraction-repulsion coefficients
grad_coeff = jnp.where(
mask,
1.0 * jax_coeff_att(dist, a, b) + 1.0 * jax_coeff_rep(dist, a, b),
0.0
)
# Sum forces from all neighbors
return jnp.sum(grad_coeff * direction, axis=0)
# ===== Repulsion Forces =====
def compute_repulsion(chunk_idx, sample_idx):
# Get positions of current point and sampled points
point_pos = positions[chunk_idx]
sample_pos = positions[sample_idx]
# Compute position differences and distances
diff = sample_pos - point_pos
dist = jnp.linalg.norm(diff, axis=1, keepdims=True)
# Avoid division by zero
mask = dist > 0
direction = jnp.where(mask, diff / dist, 0.0)
# Compute repulsion coefficients
grad_coeff = jnp.where(mask, jax_coeff_rep(dist, a, b), 0.0)
# Sum forces from all sampled points
return jnp.sum(grad_coeff * direction, axis=0)
# Vectorize force computation across all points
attraction_forces = vmap(compute_attraction)(chunk_indices, neighbor_indices)
repulsion_forces = vmap(compute_repulsion)(chunk_indices, sample_indices)
# Combine forces with cooling factor
return alpha * (attraction_forces + repulsion_forces)
#
# Auxiliary functions for force-directed layout
#
[docs]
@jax.jit
def distribution_kernel(dist, a, b):
"""
Probability kernel that maps distances to similarity scores.
This is a rational function approximation of a t-distribution.
Parameters
----------
dist : jax.numpy.ndarray
Distance values
a : float
Scale parameter that controls the steepness of the distribution
b : float
Shape parameter that controls the tail behavior
Returns
-------
float or jax.numpy.ndarray
Similarity score(s) between 0 and 1
"""
return 1.0 / (1.0 + a * dist ** (2 * b))
# Helper functions for force calculations
[docs]
@jax.jit
def jax_coeff_att(dist, a, b):
"""JAX-optimized attraction coefficient function."""
return 1.0 * distribution_kernel(1/dist, a, b)
[docs]
@jax.jit
def jax_coeff_rep(dist, a, b):
"""JAX-optimized repulsion coefficient function."""
return -1.0 * distribution_kernel(dist, a, b)
[docs]
@functools.partial(jit, static_argnums=(1, 2))
def rand_directions(key, dim=2, num=100):
"""
Sample unit vectors in random directions.
Parameters
----------
key : jax.random.PRNGKey
Random number generator key
dim : int
Dimensionality of the vectors
num : int
Number of random directions to sample
Returns
-------
jax.numpy.ndarray
Array of shape (num, dim) containing unit vectors
"""
points = random.normal(key, (num, dim))
norms = jnp.sqrt(jnp.sum(points * points, axis=-1))
return points / norms[:, None]
[docs]
@functools.partial(jit, static_argnums=(1,))
def get_slice(arr, k, i):
"""
Extract a slice of size k centered around index i.
Parameters
----------
arr : jax.numpy.ndarray
Input array
k : int
Size of the slice
i : int
Center index position
Returns
-------
jax.numpy.ndarray
Slice of the input array
"""
return lax.dynamic_slice(arr, (i - k // 2,), (k,))