Source code for dire_jax.hpindex

# hpindex.py

"""
A JAX-based implementation for efficient k-nearest neighbors.
"""

from functools import partial
import jax
import jax.numpy as jnp

#
# Double precision support
#
# Turns on JAX's "jax_enable_x64" configuration flag when this module is
# imported. The code in this module uses jnp.float64 as its default
# floating-point dtype and jnp.int64 for neighbor indices.
# NOTE(review): this is a global JAX setting, so importing this module
# presumably affects every other JAX user in the same process - confirm
# that this is intended.
jax.config.update("jax_enable_x64", True)

class HPIndex:
    """
    A kernelized kNN index that uses batching / tiling to efficiently handle
    large datasets with limited memory usage.

    All functionality is exposed through static methods; instances carry no
    state.
    """

    def __init__(self):
        pass

    @staticmethod
    def knn_tiled(x, y, k=5, x_tile_size=8192, y_batch_size=1024, dtype=jnp.float64):
        """
        Advanced implementation that tiles both database and query points.

        This wrapper handles the dynamic aspects (shape inspection, clamping
        of tile sizes, batch bookkeeping) before calling the JIT-compiled
        function with concrete, hashable parameters.

        Args:
            x: (n, d) array of database points
            y: (m, d) array of query points
            k: number of nearest neighbors
            x_tile_size: size of database tiles
            y_batch_size: size of query batches
            dtype: desired floating-point dtype (e.g., jnp.float32 or
                jnp.float64); used for the inputs, the distance computation
                and the returned distances

        Returns:
            A tuple ``(indices, distances)``:
              * indices: (m, k) integer array, row ``i`` holding the indices
                into ``x`` of the neighbors found for ``y[i]``, closest first
              * distances: (m, k) array of the matching *squared* Euclidean
                distances in ``dtype``

            If ``k`` exceeds the number of database points, the slots that
            cannot be filled carry the distance ``jnp.finfo(dtype).max``;
            the indices in those slots are not meaningful.

        Raises:
            ValueError: if ``x`` contains no points.
        """
        x = jnp.asarray(x, dtype=dtype)
        y = jnp.asarray(y, dtype=dtype)

        n_x, _ = x.shape
        n_y, _ = y.shape

        if n_x == 0:
            raise ValueError("knn_tiled requires at least one database point in x")

        if n_y == 0:
            # Nothing to search for; avoid dividing by a zero batch size below.
            return (
                jnp.zeros((0, k), dtype=jnp.int64),
                jnp.zeros((0, k), dtype=dtype),
            )

        # Ensure batch sizes aren't larger than the data dimensions
        x_tile_size = min(x_tile_size, n_x)
        y_batch_size = min(y_batch_size, n_y)

        # Calculate batching parameters
        num_y_batches, y_remainder = divmod(n_y, y_batch_size)
        num_x_tiles = (n_x + x_tile_size - 1) // x_tile_size

        # Call the JIT-compiled implementation with concrete values. dtype is
        # forwarded so the accumulators and sentinels match the inputs.
        return HPIndex._knn_tiled_jit(
            x, y, k, x_tile_size, y_batch_size,
            num_y_batches, y_remainder, num_x_tiles, n_x, dtype
        )

    @staticmethod
    @partial(jax.jit, static_argnums=(2, 3, 4, 5, 6, 7, 8, 9))
    def _knn_tiled_jit(x, y, k, x_tile_size, y_batch_size,
                       num_y_batches, y_remainder, num_x_tiles, n_x,
                       dtype=jnp.float64):
        """
        JIT-compiled implementation of tiled KNN with concrete batch parameters.

        Args:
            x: (n_x, d) array of database points
            y: (n_y, d) array of query points
            k: number of nearest neighbors (static)
            x_tile_size: rows of ``x`` per tile; must not exceed ``n_x`` (static)
            y_batch_size: rows of ``y`` per full batch (static)
            num_y_batches: number of full query batches (static)
            y_remainder: query rows left over after the full batches (static)
            num_x_tiles: number of database tiles, counting a final partial
                one (static)
            n_x: number of rows in ``x`` (static)
            dtype: floating-point dtype of the distance accumulators (static)

        Returns:
            ``(indices, distances)``, both of shape (n_y, k); rows are ordered
            by increasing squared distance.
        """
        n_y, d_y = y.shape
        _, d_x = x.shape

        # Sentinel distance for "no neighbor found yet" / "masked out".
        max_dist = jnp.finfo(dtype).max

        # Initialize results
        all_indices = jnp.zeros((n_y, k), dtype=jnp.int64)
        all_distances = jnp.full((n_y, k), max_dist, dtype=dtype)

        # Largest start row at which a full-size tile still fits inside x.
        last_full_start = n_x - x_tile_size
        tile_offsets = jnp.arange(x_tile_size)

        def search_batch(y_batch):
            """Scan every tile of x and return the top-k (indices, distances) for y_batch."""
            n_rows = y_batch.shape[0]
            init_indices = jnp.zeros((n_rows, k), dtype=jnp.int64)
            init_distances = jnp.full((n_rows, k), max_dist, dtype=dtype)

            def process_x_tile(carry, x_tile_idx):
                best_idx, best_dist = carry

                # Nominal first row of this tile.
                x_start = x_tile_idx * x_tile_size

                # dynamic_slice clamps an out-of-range start so that the slice
                # stays inside the array. Clamp explicitly instead, so the
                # index labels below describe the rows that are really read.
                slice_start = jnp.minimum(x_start, last_full_start)
                x_tile = jax.lax.dynamic_slice(
                    x, (slice_start, 0), (x_tile_size, d_x)
                )
                tile_indices = slice_start + tile_offsets

                # For the final partial tile the clamped slice overlaps rows
                # already covered by the previous tile; only rows at or past
                # the nominal start are new.
                is_new = tile_indices >= x_start

                # Compute distances between y_batch and x_tile
                tile_distances = _compute_batch_distances(y_batch, x_tile, dtype)
                tile_distances = jnp.where(
                    is_new[jnp.newaxis, :], tile_distances, max_dist
                ).astype(dtype)

                tile_indices = jnp.broadcast_to(
                    tile_indices.astype(best_idx.dtype), tile_distances.shape
                )

                # Merge current tile results with previous results
                combined_distances = jnp.concatenate([best_dist, tile_distances], axis=1)
                combined_indices = jnp.concatenate([best_idx, tile_indices], axis=1)

                # Sort and keep the k smallest distances per query row
                top_k_idx = jnp.argsort(combined_distances, axis=1)[:, :k]
                new_dist = jnp.take_along_axis(combined_distances, top_k_idx, axis=1)
                new_idx = jnp.take_along_axis(combined_indices, top_k_idx, axis=1)

                return (new_idx, new_dist), None

            (found_idx, found_dist), _ = jax.lax.scan(
                process_x_tile,
                (init_indices, init_distances),
                jnp.arange(num_x_tiles)
            )
            return found_idx, found_dist

        # Define the scan function for processing full y batches
        def process_y_batch(carry, y_batch_idx):
            curr_indices, curr_distances = carry

            # Get current batch of query points
            y_start = y_batch_idx * y_batch_size
            y_batch = jax.lax.dynamic_slice(y, (y_start, 0), (y_batch_size, d_y))

            batch_indices, batch_distances = search_batch(y_batch)

            # Update overall results for this batch
            curr_indices = jax.lax.dynamic_update_slice(
                curr_indices, batch_indices, (y_start, 0)
            )
            curr_distances = jax.lax.dynamic_update_slice(
                curr_distances, batch_distances, (y_start, 0)
            )

            return (curr_indices, curr_distances), None

        # Process all full y batches
        (all_indices, all_distances), _ = jax.lax.scan(
            process_y_batch,
            (all_indices, all_distances),
            jnp.arange(num_y_batches)
        )

        # y_remainder is a static Python int, so an ordinary `if` decides at
        # trace time whether the leftover query rows need processing at all.
        if y_remainder > 0:
            y_start = num_y_batches * y_batch_size
            remainder_y = jax.lax.dynamic_slice(y, (y_start, 0), (y_remainder, d_y))

            remainder_indices, remainder_distances = search_batch(remainder_y)

            all_indices = jax.lax.dynamic_update_slice(
                all_indices, remainder_indices, (y_start, 0)
            )
            all_distances = jax.lax.dynamic_update_slice(
                all_distances, remainder_distances, (y_start, 0)
            )

        return all_indices, all_distances
# Globally define the _compute_batch_distances function for reuse @partial(jax.jit, static_argnums=(2,)) def _compute_batch_distances(y_batch, x, dtype=jnp.float64): """ Compute the squared distances between a batch of query points and all database points. Args: y_batch: (batch_size, d) array of query points x: (n, d) array of database points Returns: (batch_size, n) array of squared distances """ # Compute squared norms x_norm = jnp.sum(x**2, axis=1) y_norm = jnp.sum(y_batch**2, axis=1) # Compute xy term xy = jnp.dot(y_batch, x.T) # Complete squared distance: ||y||² + ||x||² - 2*<y,x> dists2 = y_norm[:, jnp.newaxis] + x_norm[jnp.newaxis, :] - 2 * xy dists2 = jnp.clip(dists2, 0, jnp.finfo(dtype).max) return dists2