Source code for dire_jax.hpindex

# hpindex.py

"""
A JAX-based implementation for efficient k-nearest neighbors.
"""

from functools import partial, lru_cache
import jax
import jax.numpy as jnp

#
# Double precision support
#
jax.config.update("jax_enable_x64", True)

class HPIndex:
    """
    A kernelized kNN index that uses batching / tiling to efficiently
    handle large datasets with limited memory usage.
    """

    def __init__(self):
        pass
    @staticmethod
    def knn_tiled(x, y, k=5, x_tile_size=8192, y_batch_size=1024, dtype=jnp.float64):
        """
        Single-kernel kNN implementation that compiles once and is reused efficiently.

        Uses a single JIT-compiled kernel with fixed tile/batch parameters.

        Args:
            x: (n, d) array of database points
            y: (m, d) array of query points
            k: number of nearest neighbors
            x_tile_size: size of database tiles
            y_batch_size: size of query batches
            dtype: desired floating-point dtype (e.g., jnp.float32 or jnp.float64)

        Returns:
            Tuple of two (m, k) arrays: indices and squared L2 distances of the nearest neighbors
        """
        # Get or compile the kernel for this configuration
        kernel = HPIndex._get_knn_kernel(k, x_tile_size, y_batch_size, dtype)

        # Call the compiled kernel
        return kernel(x, y)
    @staticmethod
    @lru_cache(maxsize=16)  # Cache compiled kernels for different configurations
    def _get_knn_kernel(k, x_tile_size, y_batch_size, dtype):
        """
        Get or create a cached JIT-compiled kNN kernel for the given configuration.

        This ensures we reuse compiled kernels across different datasets with the same parameters.
        """
        @jax.jit
        def knn_kernel(x, y):
            # Ensure consistent dtypes
            x = x.astype(dtype)
            y = y.astype(dtype)

            n_x, d_x = x.shape
            n_y, d_y = y.shape

            # Pad data to tile/batch boundaries
            padded_n_x = ((n_x + x_tile_size - 1) // x_tile_size) * x_tile_size
            padded_n_y = ((n_y + y_batch_size - 1) // y_batch_size) * y_batch_size

            # Pad x if needed
            if padded_n_x > n_x:
                x_pad = jnp.full((padded_n_x - n_x, d_x), jnp.finfo(dtype).max / 2, dtype=dtype)
                x_padded = jnp.concatenate([x, x_pad], axis=0)
            else:
                x_padded = x

            # Pad y if needed
            if padded_n_y > n_y:
                y_pad = jnp.zeros((padded_n_y - n_y, d_y), dtype=dtype)
                y_padded = jnp.concatenate([y, y_pad], axis=0)
            else:
                y_padded = y

            # Calculate number of tiles/batches
            num_y_batches = padded_n_y // y_batch_size
            num_x_tiles = padded_n_x // x_tile_size

            # Initialize results
            all_indices = jnp.zeros((padded_n_y, k), dtype=jnp.int64)
            all_distances = jnp.ones((padded_n_y, k), dtype=dtype) * jnp.finfo(dtype).max

            # Get distance kernel for this dtype
            distance_kernel = _get_distance_kernel(dtype)

            # Main processing loop using scan for efficiency
            def process_y_batch(carry, y_batch_idx):
                curr_indices, curr_distances = carry
                y_start = y_batch_idx * y_batch_size
                y_batch = jax.lax.dynamic_slice(y_padded, (y_start, 0), (y_batch_size, d_y))

                batch_indices = jnp.zeros((y_batch_size, k), dtype=jnp.int64)
                batch_distances = jnp.ones((y_batch_size, k), dtype=dtype) * jnp.finfo(dtype).max

                def process_x_tile(tile_carry, x_tile_idx):
                    batch_idx, batch_dist = tile_carry
                    x_start = x_tile_idx * x_tile_size
                    x_tile = jax.lax.dynamic_slice(x_padded, (x_start, 0), (x_tile_size, d_x))

                    # Compute distances
                    tile_distances = distance_kernel(y_batch, x_tile)

                    # Create tile indices
                    tile_indices = jnp.arange(x_tile_size) + x_start
                    tile_indices = jnp.broadcast_to(tile_indices, tile_distances.shape)

                    # Merge and get top k
                    combined_distances = jnp.concatenate([batch_dist, tile_distances], axis=1)
                    combined_indices = jnp.concatenate([batch_idx, tile_indices], axis=1)

                    top_k_idx = jnp.argsort(combined_distances)[:, :k]
                    new_batch_dist = jnp.take_along_axis(combined_distances, top_k_idx, axis=1)
                    new_batch_idx = jnp.take_along_axis(combined_indices, top_k_idx, axis=1)

                    return (new_batch_idx, new_batch_dist), None

                # Process all x tiles for this y batch
                (batch_indices, batch_distances), _ = jax.lax.scan(
                    process_x_tile,
                    (batch_indices, batch_distances),
                    jnp.arange(num_x_tiles)
                )

                # Update results
                curr_indices = jax.lax.dynamic_update_slice(curr_indices, batch_indices, (y_start, 0))
                curr_distances = jax.lax.dynamic_update_slice(curr_distances, batch_distances, (y_start, 0))

                return (curr_indices, curr_distances), None

            # Process all y batches
            (all_indices, all_distances), _ = jax.lax.scan(
                process_y_batch,
                (all_indices, all_distances),
                jnp.arange(num_y_batches)
            )

            # Return only valid portion
            return all_indices[:n_y], all_distances[:n_y]

        return knn_kernel
# Module-level distance kernels, defined once for reuse by HPIndex and
# _compute_batch_distances_l2. lru_cache avoids recompilation for different
# dtype combinations.
@lru_cache(maxsize=8)  # Cache different dtype combinations
def _get_distance_kernel(dtype):
    """Get or create a JIT-compiled distance kernel for the given dtype."""
    @jax.jit
    def compute_distances(y_batch, x):
        # Ensure consistent dtype
        y_batch = y_batch.astype(dtype)
        x = x.astype(dtype)

        # Compute squared norms using a more numerically stable method
        x_norm = jnp.sum(x * x, axis=1)
        y_norm = jnp.sum(y_batch * y_batch, axis=1)

        # Compute the cross term with explicit precision
        xy = jnp.dot(y_batch, x.T, precision=jax.lax.Precision.DEFAULT)

        # Complete squared distance: ||y||² + ||x||² - 2*<y,x>
        # Use broadcasting with consistent dtype
        two = jnp.array(2.0, dtype=dtype)
        dists2 = y_norm[:, jnp.newaxis] + x_norm[jnp.newaxis, :] - two * xy

        # Clip to the valid range for the dtype
        zero = jnp.array(0.0, dtype=dtype)
        dists2 = jnp.maximum(dists2, zero)

        return dists2

    return compute_distances


def _compute_batch_distances_l2(y_batch, x, dtype=jnp.float64):
    """
    Compute the squared L2 distances between a batch of query points
    and all database points. Uses cached kernels to avoid recompilation.

    Args:
        y_batch: (batch_size, d) array of query points
        x: (n, d) array of database points
        dtype: data type for computation

    Returns:
        (batch_size, n) array of squared distances
    """
    # Get the cached kernel for this dtype
    distance_kernel = _get_distance_kernel(dtype)
    return distance_kernel(y_batch, x)
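

# --- Usage sketch (illustrative, not part of the dire_jax API) ---
# A minimal example of how HPIndex.knn_tiled might be called and checked against
# a brute-force reference. The data shapes, tile/batch sizes, and k below are
# assumptions chosen for the sketch, not recommended defaults.
if __name__ == "__main__":
    key_x, key_y = jax.random.split(jax.random.PRNGKey(0))
    x = jax.random.normal(key_x, (10_000, 16))  # database points (n, d)
    y = jax.random.normal(key_y, (300, 16))     # query points (m, d)

    # Tiled kNN: indices and squared L2 distances, each of shape (m, k)
    idx, dist2 = HPIndex.knn_tiled(x, y, k=10, x_tile_size=2048, y_batch_size=128)

    # Brute-force reference (feasible only because m and n are small here)
    full_dist2 = _compute_batch_distances_l2(y, x)
    ref_idx = jnp.argsort(full_dist2, axis=1)[:, :10]
    ref_dist2 = jnp.take_along_axis(full_dist2, ref_idx, axis=1)

    # The tiled and brute-force results should agree up to floating-point error
    assert jnp.allclose(dist2, ref_dist2)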