# dire.py
"""
Provides the main JAX-based class for dimensionality reduction.
The DiRe (Dimensionality Reduction) class implements a modern approach to
dimensionality reduction, leveraging JAX for efficient computation. It uses
force-directed layout techniques combined with k-nearest neighbor graph
construction to generate meaningful low-dimensional embeddings of
high-dimensional data.
This JAX implementation features:
- Fully vectorized force computation (optional chunked sampling for very large datasets)
- JIT compilation for mathematical operations
- Optimized for small to medium datasets (<50K points)
- Excellent CPU performance and research-friendly design
"""
#
# Imports
#
import functools
import gc
import os
import sys
from random import randint
# JAX-related imports
import jax
import jax.numpy as jnp
# Scientific and numerical libraries
import numpy as np
# Data structures and visualization
import pandas as pd
import plotly.express as px
from jax import device_get, device_put, jit, lax, random, vmap
from loguru import logger
from scipy.optimize import curve_fit
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import laplacian
from scipy.sparse.linalg import eigsh
from sklearn.base import TransformerMixin
from sklearn.decomposition import PCA, KernelPCA
# Utilities
from tqdm import tqdm
# Nearest neighbor search
from .hpindex import HPIndex
#
# Double precision support
#
jax.config.update("jax_enable_x64", True)
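# Note: enabling x64 is a process-wide JAX setting; it is applied here at import
# time, before any arrays are created. With mpa=True the kernels below still
# down-cast their inputs to float32.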
#
# Main class for Dimensionality Reduction
#
class DiRe(TransformerMixin):
"""
Dimension Reduction (DiRe) is a class designed to reduce the dimensionality of high-dimensional
data using various embedding techniques and optimization algorithms. It supports embedding
initialization methods such as random and spectral embeddings and utilizes k-nearest neighbors
for manifold learning.
Parameters
----------
n_components: (int) Embedding dimension, default 2.
n_neighbors: (int) Number of nearest neighbors to consider for each point, default 16.
init: (str)
Method to initialize the embedding; choices are:
- 'random' for random projection based on the Johnson-Lindenstrauss lemma;
- 'spectral' for spectral embedding (with sim_kernel as similarity kernel);
- 'pca' for PCA embedding (classical, no kernel).
By default, 'random'.
sim_kernel: (callable)
A similarity kernel function that transforms a distance metric to a similarity score.
The function should have the form `lambda distance: float -> similarity: float`; default `None`.
pca_kernel: (callable)
Kernel function used by Kernel PCA when `init` is 'pca'; if `None`, classical PCA is used; default `None`.
max_iter_layout: (int)
Maximum number of iterations to run the layout optimization, default 128.
min_dist: (float)
Minimum distance scale for distribution kernel, default 0.01.
spread: (float)
Spread of the distribution kernel, default 1.0.
cutoff: (float)
Cutoff for clipping forces during layout optimization, default 42.0.
n_sample_dirs: (int)
Number of directions to sample in random sampling, default 8.
sample_size: (int)
Number of samples per direction in random sampling, default 16.
neg_ratio: (int)
Ratio of negative to positive samples in random sampling, default 8.
my_logger: (logger.Logger or `None`)
Custom logger for logging events; if None, a default logger is created, default `None`.
verbose: (bool)
Flag to enable verbose output, default `True`.
random_state: (int)
Random seed to make stochastic computations reproducible.
Attributes
----------
n_components: int
Target dimensionality of the output space.
n_neighbors: int
Number of neighbors to consider in the k-nearest neighbors graph.
init: str
Chosen method for initial embedding.
sim_kernel: callable
Similarity kernel function to be used if 'init' is 'spectral', by default `None`.
pca_kernel: callable
Kernel function to be used if 'init' is 'pca', by default `None`.
max_iter_layout: int
Maximum iterations for optimizing the layout.
min_dist: float
Minimum distance for repulsion used in the distribution kernel.
spread: float
Spread between the data points used in the distribution kernel.
cutoff: float
Maximum cutoff for forces during optimization.
n_sample_dirs: int
Number of random directions sampled.
sample_size: int or 'auto'
Number of samples per random direction, unless chosen automatically with 'auto'.
batch_size : int or None, optional
Number of samples to process at once. If None, a suitable value
will be automatically determined based on dataset size.
neg_ratio: int
Ratio of negative to positive samples in the sampling process.
logger: logger.Logger or `None`
Logger used for logging informational and warning messages.
verbose: bool
Logger output flag (True = emit logger messages, False = send them to os.devnull)
memm: dictionary or `None`
Memory manager: a dictionary with the batch / memory tile size for different
hardware architectures. Accepts 'tpu', 'gpu' and 'other' as keys. Values must
be positive integers.
mpa: bool
Mixed Precision Arithmetic flag (True = use MPA, False = always use float64)
Methods
-------
fit_transform(data)
A convenience method that fits the model and then transforms the data.
The separate `fit` and `transform` methods must be called in sequence on the
same dataset, because the reduction is computed for the dataset as a whole.
visualize(labels=None, point_size=2)
Visualizes the transformed data, optionally using labels to color the points.
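Examples
--------
A minimal usage sketch (illustrative only; the import path below is an
assumption, not something this module prescribes):
>>> import numpy as np
>>> from dire import DiRe
>>> X = np.random.rand(1000, 50)
>>> reducer = DiRe(n_components=2, n_neighbors=16, random_state=42)
>>> embedding = reducer.fit_transform(X)  # array of shape (1000, 2)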
"""
def __init__(
self,
n_components=2,
n_neighbors=16,
init="random",
sim_kernel=None,
pca_kernel=None,
max_iter_layout=128,
min_dist=1e-2,
spread=1.0,
cutoff=42.0,
n_sample_dirs=8,
sample_size=16,
batch_size=None,
neg_ratio=8,
my_logger=None,
verbose=True,
memm=None,
mpa=True,
random_state=None,
):
"""
Class constructor
"""
#
self.n_components = n_components
""" Embedding dimension """
self.n_neighbors = n_neighbors
""" Number of neighbors for kNN computations"""
self.init = init
""" Type of the initial embedding (PCA, random, spectral) """
self.sim_kernel = sim_kernel
""" Similarity kernel """
self.pca_kernel = pca_kernel
""" PCA kernel """
self.max_iter_layout = max_iter_layout
""" Max iterations for the force layout """
self.min_dist = min_dist
""" Min distance between points in layout """
self.spread = spread
""" Layout spread """
self.cutoff = cutoff
""" Cutoff for layout displacement """
self.n_sample_dirs = n_sample_dirs
""" Number of sampling directions for layout"""
self.sample_size = sample_size
""" Sample size for attraction """
self.neg_ratio = neg_ratio
""" Ratio for repulsion sample size """
self._init_embedding = None
""" Initial embedding """
self._layout = None
""" Layout output """
self._a = None
""" Probability kernel parameter """
self._b = None
""" Probability kernel parameter """
self._data = None
""" Higher-dimensional data """
self._n_samples = None
""" Number of data points """
self._data_dim = None
""" Dimension of data """
self._distances_np = None
self._distances_jax = None
""" Distances in the kNN graph """
self._indices_np = None
self._indices_jax = None
""" Neighbor indices in the kNN graph """
self._nearest_neighbor_distances = None
""" Neighbor indices in the kNN graph, excluding the point itself """
self._row_idx = None
""" Row indices for nearest neighbors """
self._col_idx = None
""" Column indices for nearest neighbors """
self._adjacency = None
""" kNN adjacency matrix """
self.random_state = random_state
self.batch_size = batch_size
#
if my_logger is None:
logger.remove()
sink = sys.stdout if verbose else open(os.devnull, "w", encoding="utf-8")
logger.add(sink, level="INFO")
self.logger = logger
""" System logger """
else:
self.logger = my_logger
# Memory manager to be adjusted for each particular type of hardware
# Below are some minimalist settings that may give less than satisfactory performance
self.memm = {"gpu": 16384, "tpu": 8192, "other": 8192} if memm is None else memm
# Using Mixed Precision Arithmetic flag (True = MPA, False = no MPA)
self.mpa = mpa
#
# Fitting the distribution kernel with given min_dist and spread
#
def find_ab_params(self, min_dist=0.01, spread=1.0):
"""
Rational function approximation to the probabilistic t-kernel
"""
#
self.logger.info("find_ab_params ...")
def curve(x, a, b):
return 1.0 / (1.0 + a * x ** (2 * b))
#
xv = np.linspace(0, spread * 3, 300)
yv = np.zeros(xv.shape)
yv[xv < min_dist] = 1.0
yv[xv >= min_dist] = np.exp(-(xv[xv >= min_dist] - min_dist) / spread)
params, _ = curve_fit(curve, xv, yv)
#
self.logger.info(f"a = {params[0]}, b = {params[1]}")
self.logger.info("find_ab_params done ...")
#
return params[0], params[1]
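#
# Illustrative use of the fitted parameters (a sketch; `reducer` stands for an
# existing DiRe instance and is not defined here):
#
#     a, b = reducer.find_ab_params(min_dist=0.01, spread=1.0)
#     # The kernel 1 / (1 + a * d**(2 * b)) then approximates a curve that is
#     # ~1 for d < min_dist and decays like exp(-(d - min_dist) / spread) beyond it.
#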
#
# Fitting on data
#
def fit(self, X: np.ndarray, y=None):
"""
Fit the model to data: create the kNN graph and fit the probability kernel to force layout parameters.
Parameters
----------
X: (numpy.ndarray)
High-dimensional data to fit the model. Shape (n_samples, n_features).
y: None
Ignored, exists for sklearn compatibility.
Returns
-------
self: The DiRe instance fitted to data.
"""
self.fit_transform(X, y)
return self
#
# Computing the kNN adjacency matrix (sparse)
#
def make_knn_adjacency(self, batch_size=None):
"""
Internal routine building the adjacency matrix for the kNN graph.
This method computes the k-nearest neighbors for each point in the dataset
and constructs a sparse adjacency matrix representing the kNN graph.
It attempts to use GPU acceleration if available, with a fallback to CPU.
For large datasets, it uses batching to limit memory usage.
Parameters
----------
batch_size : int or None, optional
Number of samples to process at once. If None, a suitable value
will be automatically determined based on dataset size.
The method sets the following instance attributes:
- _distances_np / _distances_jax: Distances to the k nearest neighbors (including self)
- _indices_np / _indices_jax: Indices of the k nearest neighbors (including self)
- _nearest_neighbor_distances: Distances to the nearest neighbors (excluding self)
- _row_idx, _col_idx: Indices for constructing the sparse adjacency matrix
- _adjacency: Sparse adjacency matrix of the kNN graph
"""
self.logger.info("make_knn_adjacency ...")
# Ensure data is in the right format for HPIndex
self._data = np.ascontiguousarray(self._data.astype(np.float64))
n_neighbors = self.n_neighbors + 1 # Including the point itself
# Determine appropriate batch size for memory efficiency
if batch_size is None:
# Process in chunks to reduce peak memory usage
if jax.devices()[0].platform == "tpu":
batch_size = min(self.memm["tpu"], self._n_samples)
elif jax.devices()[0].platform == "gpu":
batch_size = min(self.memm["gpu"], self._n_samples)
else:
batch_size = min(self.memm["other"], self._n_samples)
self.logger.info(f"Using batch size: {batch_size}")
self.logger.debug(
f"[KNN] Using precision: {'float32' if self.mpa else 'float64'}"
)
if self.mpa:
self._indices_jax, self._distances_jax = HPIndex.knn_tiled(
self._data,
self._data,
n_neighbors,
batch_size,
batch_size,
dtype=jnp.float32,
)
else:
self._indices_jax, self._distances_jax = HPIndex.knn_tiled(
self._data,
self._data,
n_neighbors,
batch_size,
batch_size,
dtype=jnp.float64,
)
# Wait until ready
self._indices_jax.block_until_ready()
self._distances_jax.block_until_ready()
# Store results in numpy
self._indices_np = device_get(self._indices_jax).astype(np.int64)
self._distances_np = device_get(self._distances_jax).astype(np.float64)
# Extract nearest neighbor distances (excluding self)
self._nearest_neighbor_distances = self._distances_np[:, 1:]
# Create indices for sparse matrix construction
self._row_idx = np.repeat(np.arange(self._n_samples), n_neighbors)
self._col_idx = self._indices_np.ravel()
# Create sparse adjacency matrix (memory efficient)
data_values = self._distances_np.ravel()
self._adjacency = csr_matrix(
(data_values, (self._row_idx, self._col_idx)),
shape=(self._n_samples, self._n_samples),
)
# Clean up resources
gc.collect()
self.logger.info("make_knn_adjacency done ...")
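#
# Inspection sketch (hypothetical; assumes a fitted instance `reducer`):
#
#     row = reducer._adjacency.getrow(i)   # sparse row for point i
#     # row.indices -> neighbor ids, row.data -> distances (self stored with d = 0)
#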
#
# Initialize embedding using different techniques
#
def do_pca_embedding(self):
"""
Initialize embedding using Principal Component Analysis (PCA).
This method creates an initial embedding of the data using PCA, which finds
a linear projection of the high-dimensional data into a lower-dimensional space
that maximizes the variance. If a kernel is specified, Kernel PCA is used instead,
which can capture nonlinear relationships.
Sets the init_embedding attribute with the PCA projection of the data.
"""
self.logger.info("do_pca_embedding ...")
if self.pca_kernel is not None:
# Use Kernel PCA for nonlinear dimensionality reduction
self.logger.info("Using kernelized PCA embedding...")
pca = KernelPCA(n_components=self.n_components, kernel=self.pca_kernel)
self._init_embedding = pca.fit_transform(self._data)
else:
# Use standard PCA for linear dimensionality reduction
self.logger.info("Using standard PCA embedding...")
pca = PCA(n_components=self.n_components)
self._init_embedding = pca.fit_transform(self._data)
self.logger.info("do_pca_embedding done ...")
def do_spectral_embedding(self):
"""
Initialize embedding using Spectral Embedding.
This method creates an initial embedding of the data using spectral embedding,
which is based on the eigenvectors of the graph Laplacian. It relies on the
kNN graph structure to find a lower-dimensional representation that preserves
local relationships.
If a similarity kernel is specified, it is applied to transform the distances
in the adjacency matrix before computing the Laplacian.
Sets the init_embedding attribute with the spectral embedding of the data.
"""
self.logger.info("do_spectral_embedding ...")
# Apply similarity kernel if provided
if self.sim_kernel is not None:
self.logger.info("Applying similarity kernel to adjacency matrix...")
# Transform distances using the similarity kernel
data_values = self.sim_kernel(self._adjacency.data)
# Create a new adjacency matrix with transformed values
adj_mat = csr_matrix(
(data_values, (self._row_idx, self._col_idx)),
shape=(self._n_samples, self._n_samples),
)
else:
adj_mat = self._adjacency
# Make the adjacency matrix symmetric by adding it to its transpose
symmetric_adj = adj_mat + adj_mat.T
# Compute the normalized Laplacian
lap = laplacian(symmetric_adj, normed=True)
# Compute the eigenvectors for the k smallest eigenvalues (k = n_components + 1)
k = self.n_components + 1
_, eigenvectors = eigsh(lap, k, which="SM")
# Skip the first eigenvector (associated with the smallest, trivial eigenvalue)
self._init_embedding = eigenvectors[:, 1:k]
self.logger.info("do_spectral_embedding done ...")
def do_random_embedding(self):
"""
Initialize embedding using Random Projection.
This method creates an initial embedding of the data using random projection,
which is a simple and computationally efficient technique for dimensionality
reduction. It projects the data onto a randomly generated basis, providing
a good starting point for further optimization.
Random projection is supported by the Johnson-Lindenstrauss lemma, which
guarantees that the distances between points are approximately preserved
under certain conditions.
Sets the init_embedding attribute with the random projection of the data.
"""
self.logger.info("do_random_embedding ...")
# Create a random projection matrix
if self.random_state is None:
key = random.PRNGKey(randint(0, 1000))
else:
key = random.PRNGKey(self.random_state)
# Use appropriate dtype for MPA
compute_dtype = jnp.float32 if self.mpa else jnp.float64
rand_basis = random.normal(key, (self.n_components, self._data_dim), dtype=compute_dtype)
# Move data and projection matrix to device memory with consistent dtype
data_matrix = device_put(self._data.astype(compute_dtype))
rand_basis = device_put(rand_basis)
# Project data onto random basis with explicit precision control
self._init_embedding = jnp.dot(
data_matrix, rand_basis.T,
precision=jax.lax.Precision.DEFAULT if self.mpa else jax.lax.Precision.HIGHEST
)
self.logger.info("do_random_embedding done ...")
#
# Efficient sampling for force-directed layout
#
def do_rand_sampling(self, key, arr, n_samples, n_dirs, neg_ratio):
"""
Sample points for force calculation using cached sampling kernel.
Uses a cached kernel to avoid recompilation issues with dynamic vmap.
"""
self.logger.info("do_rand_sampling ...")
arr_len = len(arr)
compute_dtype = jnp.float32 if hasattr(self, 'mpa') and self.mpa else jnp.float64
# Use cached sampling kernel
sampling_kernel = _get_sampling_kernel(
self.n_components, n_samples, n_dirs, neg_ratio, compute_dtype
)
sampled_indices = sampling_kernel(key, arr)
self.logger.info("do_rand_sampling done ...")
return sampled_indices
#
# Create layout using force-directed optimization
#
def do_layout(self, large_dataset_mode=None, force_cpu=False):
"""
Optimize the layout using force-directed placement with JAX kernels.
This method takes the initial embedding and iteratively refines it using
attractive and repulsive forces to create a meaningful low-dimensional
representation of the high-dimensional data. The algorithm applies:
1. Attraction forces between points that are neighbors in the high-dimensional space
2. Repulsion forces between randomly sampled points in the low-dimensional space
3. Gradual cooling (decreasing force impact) as iterations progress
The final layout is normalized to have zero mean and unit standard deviation.
Parameters
----------
large_dataset_mode : bool or None, optional
If True, use memory-efficient techniques for large datasets.
If None, automatically determine based on dataset size.
force_cpu : bool, optional
If True, force computations on CPU instead of GPU, which can
be helpful for large datasets that exceed GPU memory.
"""
self.logger.info("do_layout ...")
# Setup parameters with appropriate dtype
compute_dtype = jnp.float32 if self.mpa else jnp.float64
# Handle automatic sample size calculation if needed
sample_size = self.sample_size
num_iterations = self.max_iter_layout
if sample_size == "auto":
# Scale sample size based on dataset size and neighborhood size
sample_size = int(self.n_neighbors * np.log(self._n_samples))
# Ensure sample size is reasonable (not too small or too large)
sample_size = max(min(512, sample_size), 32)
# Determine if we should use memory-efficient mode for large datasets
if large_dataset_mode is None:
large_dataset_mode = (self._n_samples > 65536) or (
jax.devices()[0].platform == "tpu"
)
# Other parameters
n_dirs = self.n_sample_dirs
neg_ratio = self.neg_ratio
# Debug initial embedding precision
self.logger.debug(
f"[LAYOUT] Initial embedding precision: {self._init_embedding.dtype}"
)
# Initialize and normalize positions with appropriate dtype
if force_cpu:
self.logger.info("Forcing computations on CPU")
cpu_device = jax.devices("cpu")[0]
init_pos_jax = device_put(self._init_embedding.astype(compute_dtype), device=cpu_device)
neighbor_indices_jax = device_put(self._indices_np, device=cpu_device)
else:
init_pos_jax = device_put(self._init_embedding.astype(compute_dtype))
neighbor_indices_jax = device_put(self._indices_jax)
# Use dtype-consistent operations for normalization
mean_pos = init_pos_jax.mean(axis=0, keepdims=True)
init_pos_jax = init_pos_jax - mean_pos # Center positions
std_pos = init_pos_jax.std(axis=0, keepdims=True)
# Avoid division by zero with appropriate epsilon for dtype
eps = jnp.array(1e-7 if self.mpa else 1e-15, dtype=compute_dtype)
init_pos_jax = init_pos_jax / jnp.maximum(std_pos, eps) # Normalize variance
# Set random seed for reproducibility
if self.random_state is None:
key = random.PRNGKey(randint(0, 1000))
else:
key = random.PRNGKey(self.random_state)
# Use cached layout optimization kernel - this replaces the entire Python loop
# with a single JIT-compiled operation that avoids all recompilation issues
self.logger.info(f"Using cached layout optimization kernel for {num_iterations} iterations")
# Use a reasonable max_iterations to allow caching across different iteration counts
# This enables a single kernel to handle 8, 16, 32, 64, 128 iterations efficiently
max_iterations = max(128, num_iterations)
# Calculate chunk size for large dataset mode
chunk_size = None
if large_dataset_mode:
platform = jax.devices()[0].platform
base_chunk_size = self.memm.get(platform, self.memm.get("other", 8192))
# Reduce chunk size significantly for large dataset mode
chunk_size = max(base_chunk_size // 4, 1024) # At least 1024, but 1/4 of normal batch size
self.logger.info(f"Using large dataset mode with chunk size: {chunk_size} (platform: {platform})")
else:
self.logger.info("Using standard dataset mode (no chunking)")
layout_kernel = _get_layout_optimization_kernel(
max_iterations=max_iterations,
sample_size=sample_size,
n_dirs=n_dirs,
neg_ratio=neg_ratio,
cutoff=self.cutoff,
a=self._a,
b=self._b,
use_float32=self.mpa,
large_dataset_mode=large_dataset_mode,
chunk_size=chunk_size
)
# Run the complete optimization in a single JIT-compiled call with dynamic iterations
final_positions = layout_kernel(init_pos_jax, neighbor_indices_jax, key, num_iterations)
# Store final layout
self._layout = np.asarray(final_positions)
# Clear any cached values to free memory
gc.collect()
self.logger.info("do_layout done ...")
# Modified _compute_forces method to use the kernel
def _compute_forces(
self, positions, chunk_indices, neighbor_indices, sample_indices, alpha=1.0
):
"""
Compute attractive and repulsive forces for points using JAX kernels.
This method uses JAX-optimized kernels to efficiently compute forces
between points during layout optimization.
Parameters
----------
positions : jax.numpy.ndarray
Current point positions
chunk_indices : jax.numpy.ndarray
Current batch indices
neighbor_indices : jax.numpy.ndarray
Indices of neighbors for attractive forces
sample_indices : jax.numpy.ndarray
Indices of points for repulsive forces
alpha : float
Cooling factor that scales force magnitude
Returns
-------
jax.numpy.ndarray
Net force vectors for each point
"""
self.logger.debug(f"[FORCE] Computing forces on device: {positions.device}")
self.logger.debug(f"[FORCE] Using MPA: {self.mpa}")
# Call the JAX-optimized kernel with MPA flag
return compute_forces_kernel(
positions,
chunk_indices,
neighbor_indices,
sample_indices,
alpha,
self._a,
self._b,
use_float32=self.mpa,
)
#
# Visualize the layout
#
def visualize(
self,
labels=None,
point_size=2,
title=None,
colormap=None,
width=800,
height=600,
opacity=0.7,
):
"""
Generate an interactive visualization of the data in the transformed space.
This method creates a scatter plot visualization of the embedded data, supporting
both 2D and 3D visualizations depending on the specified dimension. Points can be
colored by provided labels for clearer visualization of clusters or categories.
Parameters
----------
labels : numpy.ndarray or None, optional
Labels for each data point to color the points in the visualization.
If None, all points will have the same color. Default is None.
point_size : int or float, optional
Size of points in the scatter plot. Default is 2.
title : str or None, optional
Title for the visualization. If None, a default title will be used. Default is None.
colormap : str or None, optional
Name of the colormap to use for labels (e.g., 'viridis', 'plasma').
If None, the default Plotly colormap will be used. Default is None.
width : int, optional
Width of the figure in pixels. Default is 800.
height : int, optional
Height of the figure in pixels. Default is 600.
opacity : float, optional
Opacity of the points (0.0 to 1.0). Default is 0.7.
Returns
-------
plotly.graph_objs._figure.Figure or None
A Plotly figure object if the visualization is successful;
None if no layout is available or dimension > 3.
Notes
-----
For 3D visualizations, you can rotate, zoom, and pan the plot interactively.
For both 2D and 3D, hover over points to see their coordinates and labels.
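Examples
--------
Illustrative call (assumes a fitted instance ``reducer`` and an optional label
array ``y``; neither is defined in this module):
>>> fig = reducer.visualize(labels=y, point_size=3)
>>> fig.show()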
"""
# Check if layout is available
if self._layout is None:
self.logger.warning("visualize ERROR: no layout available")
return None
# Set default title if not provided
if title is None:
title = (
f"{self.init.capitalize()} Initialized {self.n_components}D Embedding"
)
# Common visualization parameters
vis_params = {
"color": "label" if labels is not None else None,
"color_continuous_scale": colormap,
"opacity": opacity,
"title": title,
"hover_data": ["label"] if labels is not None else None,
}
# Create 2D visualization
if self.n_components == 2:
self.logger.info("visualize: 2D ...")
# Create dataframe for plotting
datadf = pd.DataFrame(self._layout, columns=["x", "y"])
# Add labels if provided
if labels is not None:
datadf["label"] = labels
# Create scatter plot
fig = px.scatter(datadf, x="x", y="y", **vis_params)
# Update layout
fig.update_layout(
width=width,
height=height,
xaxis_title="x",
yaxis_title="y",
)
# Create 3D visualization
elif self.n_components == 3:
self.logger.info("visualize: 3D ...")
# Create dataframe for plotting
datadf = pd.DataFrame(self._layout, columns=["x", "y", "z"])
# Add labels if provided
if labels is not None:
datadf["label"] = labels
# Create 3D scatter plot
fig = px.scatter_3d(datadf, x="x", y="y", z="z", **vis_params)
# Update layout
fig.update_layout(
width=width,
height=height,
scene={
"xaxis_title": "x",
"yaxis_title": "y",
"zaxis_title": "z",
},
)
# Return None for higher dimensions
else:
self.logger.warning("visualize ERROR: dimension > 3")
return None
# Update marker properties
fig.update_traces(marker={"size": point_size})
return fig
#
# Kernel for force-directed layout
#
# Cache force computation kernels to avoid recompilation
@functools.lru_cache(maxsize=8)
def _get_force_computation_kernel(a, b, use_float32):
"""Get or create a cached force computation kernel for the given parameters."""
@jax.jit
def force_kernel(positions, chunk_indices, neighbor_indices, sample_indices, alpha):
return _compute_forces_impl(positions, chunk_indices, neighbor_indices, sample_indices, alpha, a, b, use_float32)
return force_kernel
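# Note on caching: functools.lru_cache keys on its arguments, so a, b and
# use_float32 must be hashable scalars (JAX arrays are not hashable); callers
# pass plain Python/NumPy scalars rather than device arrays.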
def compute_forces_kernel(
positions, chunk_indices, neighbor_indices, sample_indices, alpha, a, b, use_float32=True
):
"""
Cached wrapper for force computation to avoid recompilation.
Gets or creates a cached kernel for the given a, b, use_float32 configuration.
"""
kernel = _get_force_computation_kernel(a, b, use_float32)
return kernel(positions, chunk_indices, neighbor_indices, sample_indices, alpha)
def _compute_forces_impl(
positions, chunk_indices, neighbor_indices, sample_indices, alpha, a, b, use_float32
):
"""
Fully vectorized JAX implementation of force computation.
This is the core computation that gets JIT-compiled and cached.
"""
# Convert to appropriate dtype at the beginning
compute_dtype = jnp.float32 if use_float32 else jnp.float64
positions = positions.astype(compute_dtype)
alpha = jnp.array(alpha, dtype=compute_dtype)
a = jnp.array(a, dtype=compute_dtype)
b = jnp.array(b, dtype=compute_dtype)
# Get current positions for processing
current_positions = positions[chunk_indices] # (N, D)
forces = jnp.zeros_like(current_positions, dtype=compute_dtype)
# ===== ATTRACTION FORCES (k-NN only) =====
# Get neighbor positions efficiently using advanced indexing
neighbor_positions = positions[neighbor_indices] # (N, k, D)
# Broadcast current positions for vectorized computation
current_pos_expanded = current_positions[:, None, :] # (N, 1, D)
# Compute differences and distances
att_diff = neighbor_positions - current_pos_expanded # (N, k, D)
# Use more stable distance computation for MPA
att_dist_sq = jnp.sum(att_diff * att_diff, axis=2, keepdims=True) # (N, k, 1)
att_dist = jnp.sqrt(att_dist_sq + jnp.array(1e-10, dtype=compute_dtype)) # (N, k, 1)
# Attraction coefficient: 1 / (1 + a * (1/d)^(2b))
# Use more numerically stable computation
inv_dist = jnp.array(1.0, dtype=compute_dtype) / att_dist
att_coeff = jnp.array(1.0, dtype=compute_dtype) / (jnp.array(1.0, dtype=compute_dtype) + a * jnp.power(inv_dist, 2.0 * b)) # (N, k, 1)
# Compute attraction forces and sum over neighbors
att_forces = jnp.sum(att_coeff * att_diff / att_dist, axis=1) # (N, D)
forces += att_forces
# ===== REPULSION FORCES (Random Sampling) =====
if sample_indices.size > 0:
# Get negative sample positions
sample_positions = positions[sample_indices] # (N, n_neg, D)
# Compute differences and distances
rep_diff = sample_positions - current_pos_expanded # (N, n_neg, D)
# Use more stable distance computation for MPA
rep_dist_sq = jnp.sum(rep_diff * rep_diff, axis=2, keepdims=True) # (N, n_neg, 1)
rep_dist = jnp.sqrt(rep_dist_sq + jnp.array(1e-10, dtype=compute_dtype)) # (N, n_neg, 1)
# Repulsion coefficient: -1 / (1 + a * d^(2b))
rep_coeff = jnp.array(-1.0, dtype=compute_dtype) / (jnp.array(1.0, dtype=compute_dtype) + a * jnp.power(rep_dist, 2.0 * b)) # (N, n_neg, 1)
# Compute repulsion forces and sum over negative samples
rep_forces = jnp.sum(rep_coeff * rep_diff / rep_dist, axis=1) # (N, D)
forces += rep_forces
# Apply cooling factor
return alpha * forces
#
# Auxiliary functions for force-directed layout
#
# Cache coefficient kernels to avoid recompilation
@functools.lru_cache(maxsize=8)
def _get_distribution_kernel(a, b):
"""Get or create a cached distribution kernel for the given parameters."""
@jax.jit
def dist_kernel(dist):
return 1.0 / (1.0 + a * dist ** (2 * b))
return dist_kernel
def distribution_kernel(dist, a, b):
"""
Probability kernel that maps distances to similarity scores using cached kernels.
This is a rational function approximation of a t-distribution.
Parameters
----------
dist : jax.numpy.ndarray
Distance values
a : float
Scale parameter that controls the steepness of the distribution
b : float
Shape parameter that controls the tail behavior
Returns
-------
float or jax.numpy.ndarray
Similarity score(s) between 0 and 1
"""
kernel = _get_distribution_kernel(a, b)
return kernel(dist)
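# Illustrative call (a sketch; the numeric values depend on the fitted a and b):
#
#     sims = distribution_kernel(jnp.array([0.0, 0.5, 2.0]), a=1.5, b=1.0)
#     # -> close to 1 at d = 0, decaying toward 0 for larger distances
#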
# Cache coefficient kernels
@functools.lru_cache(maxsize=8)
def _get_coeff_att_kernel(a, b):
"""Get or create a cached attraction coefficient kernel."""
@jax.jit
def att_kernel(dist):
return 1.0 / (1.0 + a * (1.0 / dist) ** (2 * b))
return att_kernel
@functools.lru_cache(maxsize=8)
def _get_coeff_rep_kernel(a, b):
"""Get or create a cached repulsion coefficient kernel."""
@jax.jit
def rep_kernel(dist):
return -1.0 / (1.0 + a * dist ** (2 * b))
return rep_kernel
def jax_coeff_att(dist, a, b):
"""JAX-optimized attraction coefficient function with caching."""
kernel = _get_coeff_att_kernel(a, b)
return kernel(dist)
def jax_coeff_rep(dist, a, b):
"""JAX-optimized repulsion coefficient function with caching."""
kernel = _get_coeff_rep_kernel(a, b)
return kernel(dist)
# Cache random direction kernels to avoid recompilation
@functools.lru_cache(maxsize=8)
def _get_rand_directions_kernel(dim, num, dtype):
"""Get or create a cached random directions kernel for the given configuration."""
@jax.jit
def rand_dir_kernel(key):
points = random.normal(key, (num, dim), dtype=dtype)
norms = jnp.sqrt(jnp.sum(points * points, axis=-1, keepdims=True))
# Add small epsilon to avoid division by zero
eps = jnp.array(1e-7 if dtype == jnp.float32 else 1e-15, dtype=dtype)
return points / jnp.maximum(norms, eps)
return rand_dir_kernel
def rand_directions(key, dim=2, num=100, dtype=jnp.float64):
"""
Sample unit vectors in random directions using cached kernels.
Parameters
----------
key : jax.random.PRNGKey
Random number generator key
dim : int
Dimensionality of the vectors
num : int
Number of random directions to sample
dtype : jnp.dtype
Data type for computation
Returns
-------
jax.numpy.ndarray
Array of shape (num, dim) containing unit vectors
"""
kernel = _get_rand_directions_kernel(dim, num, dtype)
return kernel(key)
# Cache slice kernels to avoid recompilation
@functools.lru_cache(maxsize=16)
def _get_slice_kernel(k):
"""Get or create a cached slice kernel for the given slice size."""
@jax.jit
def slice_kernel(arr, i):
return lax.dynamic_slice(arr, (i - k // 2,), (k,))
return slice_kernel
def get_slice(arr, k, i):
"""
Extract a slice of size k centered around index i using cached kernels.
Parameters
----------
arr : jax.numpy.ndarray
Input array
k : int
Size of the slice
i : int
Center index position
Returns
-------
jax.numpy.ndarray
Slice of the input array
"""
kernel = _get_slice_kernel(k)
return kernel(arr, i)
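# Note: lax.dynamic_slice clamps out-of-range start indices, so windows requested
# near the array boundary are shifted inwards instead of raising an error.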
# Cache sampling kernels to avoid recompilation in do_rand_sampling
@functools.lru_cache(maxsize=16)
def _get_sampling_kernel(n_components, n_samples, n_dirs, neg_ratio, dtype):
"""Get or create a cached sampling kernel for the given configuration."""
@jax.jit
def sampling_kernel(key, arr):
arr_len = arr.shape[0]
sampled_indices_list = []
# Get random unit vectors for projections
key, subkey = random.split(key)
direction_vectors = _get_rand_directions_kernel(n_components, n_dirs, dtype)(subkey)
# Get slice kernel for this n_samples
slice_kernel = _get_slice_kernel(n_samples)
# For each direction, sample points based on projections
for i in range(n_dirs):
vec = direction_vectors[i]
# Project points onto the direction vector
arr_proj = jnp.dot(vec, arr.T)
# Sort indices by projection values
indices_sort = jnp.argsort(arr_proj)
# For each point, take n_samples points around it in sorted order
# Use vmap on the cached slice kernel
vmap_slice = vmap(slice_kernel, in_axes=(None, 0))
indices = vmap_slice(indices_sort, jnp.arange(arr_len))
# Reorder rows so that row p holds the window for original point p
# (argsort of a permutation yields its inverse permutation)
indices = indices[jnp.argsort(indices_sort)]
# Add to list of sampled indices
sampled_indices_list.append(indices)
# Generate random negative samples for repulsion
n_neg_samples = int(neg_ratio * n_samples)
key, subkey = random.split(key)
neg_indices = random.randint(subkey, (arr_len, n_neg_samples), 0, arr_len)
sampled_indices_list.append(neg_indices)
# Combine all sampled indices
sampled_indices = jnp.concatenate(sampled_indices_list, axis=-1)
return sampled_indices
return sampling_kernel
def _get_chunked_sampling_kernel(n_components, n_samples, n_dirs, neg_ratio, dtype, chunk_size):
"""Get or create a memory-efficient chunked sampling kernel for large datasets."""
@jax.jit
def chunked_sampling_kernel(key, arr):
arr_len = arr.shape[0]
# For large datasets, use a simpler approximation that reduces memory usage
# Instead of processing in dynamic chunks, we use subsampling to reduce the problem size
# Get random unit vectors for projections
key, subkey = random.split(key)
direction_vectors = _get_rand_directions_kernel(n_components, n_dirs, dtype)(subkey)
# Subsample the array to reduce memory footprint
subsample_size = min(chunk_size, arr_len)
key, subkey = random.split(key)
subsample_indices = random.choice(subkey, arr_len, (subsample_size,), replace=False)
subsample_arr = arr[subsample_indices]
sampled_indices_list = []
# For each direction, sample points based on projections from subsampled array
for i in range(n_dirs):
vec = direction_vectors[i]
# Project subsampled points onto the direction vector
subsample_proj = jnp.dot(vec, subsample_arr.T)
# Sort indices by projection values
indices_sort = jnp.argsort(subsample_proj)
# For each original point, find nearest points in the subsampled space
arr_proj_full = jnp.dot(vec, arr.T)
def find_neighbors_for_point(point_proj):
# Find position in sorted subsampled projections
pos = jnp.searchsorted(subsample_proj[indices_sort], point_proj)
# Sample around this position in the subsampled space
start = jnp.maximum(0, pos - n_samples // 2)
end = jnp.minimum(subsample_size, start + n_samples)
start = jnp.maximum(0, end - n_samples)
# Map back to original indices using dynamic_slice
subsampled_neighbors = jax.lax.dynamic_slice(indices_sort, (start,), (n_samples,))
return subsample_indices[subsampled_neighbors]
vmap_find_neighbors = vmap(find_neighbors_for_point)
indices = vmap_find_neighbors(arr_proj_full)
sampled_indices_list.append(indices)
# Generate random negative samples
n_neg_samples = int(neg_ratio * n_samples)
key, subkey = random.split(key)
neg_indices = random.randint(subkey, (arr_len, n_neg_samples), 0, arr_len)
sampled_indices_list.append(neg_indices)
# Combine all sampled indices
sampled_indices = jnp.concatenate(sampled_indices_list, axis=-1)
return sampled_indices
return chunked_sampling_kernel
# Cache complete layout optimization kernels with max iterations
@functools.lru_cache(maxsize=8)
def _get_layout_optimization_kernel(max_iterations, sample_size, n_dirs, neg_ratio, cutoff, a, b, use_float32, large_dataset_mode=False, chunk_size=None):
"""Get or create a cached layout optimization kernel for the given configuration."""
@jax.jit
def layout_kernel(init_positions, neighbor_indices, initial_key, actual_iterations):
"""
Complete JIT-compiled layout optimization loop with dynamic iteration count.
This replaces the entire Python for-loop with a JAX-optimized implementation.
"""
n_components = init_positions.shape[1]
compute_dtype = jnp.float32 if use_float32 else jnp.float64
# Ensure consistent dtypes
positions = init_positions.astype(compute_dtype)
cutoff_typed = jnp.array(cutoff, dtype=compute_dtype)
a_typed = jnp.array(a, dtype=compute_dtype)
b_typed = jnp.array(b, dtype=compute_dtype)
# Get cached kernels - use chunked sampling for large datasets to reduce memory usage
if large_dataset_mode and chunk_size is not None:
sampling_kernel = _get_chunked_sampling_kernel(
n_components, sample_size, n_dirs, neg_ratio, compute_dtype, chunk_size
)
else:
sampling_kernel = _get_sampling_kernel(n_components, sample_size, n_dirs, neg_ratio, compute_dtype)
force_kernel = _get_force_computation_kernel(a, b, use_float32)
def optimization_step(carry, iter_id):
current_positions, key = carry
# Split key for this iteration
key, subkey = random.split(key)
# Skip iterations beyond actual_iterations
skip_iteration = iter_id >= actual_iterations
# Calculate cooling factor (use actual_iterations for proper scaling)
alpha = jnp.where(skip_iteration, 0.0, 1.0 - iter_id / actual_iterations)
# Sample points for force calculation
sample_indices = sampling_kernel(subkey, current_positions)
# Conditionally compute forces and update positions using lax.cond
def do_update(_):
n_points = current_positions.shape[0]
chunk_indices = jnp.arange(n_points)
net_force = force_kernel(
current_positions, chunk_indices, neighbor_indices, sample_indices, alpha
)
net_force = jnp.clip(net_force, -cutoff_typed, cutoff_typed)
return current_positions + net_force
def skip_update(_):
return current_positions
new_positions = lax.cond(skip_iteration, skip_update, do_update, None)
return (new_positions, key), None
# Run the optimization loop using scan for efficiency with max iterations
(final_positions, _), _ = lax.scan(
optimization_step,
(positions, initial_key),
jnp.arange(max_iterations)
)
# Final normalization
mean_final = final_positions.mean(axis=0, keepdims=True)
final_positions = final_positions - mean_final
std_final = final_positions.std(axis=0, keepdims=True)
eps = jnp.array(1e-7 if use_float32 else 1e-15, dtype=compute_dtype)
final_positions = final_positions / jnp.maximum(std_final, eps)
return final_positions
return layout_kernel