# dire_pytorch_memory_efficient.py
"""
Memory-efficient PyTorch/PyKeOps backend for DiRe.

This implementation inherits from DiRePyTorch and overrides specific methods for:

- FP16 support for memory-efficient k-NN computation
- Point-by-point attraction force computation to avoid large tensor materialization
- More aggressive memory management and cache clearing
- Optional PyKeOps LazyTensors for repulsion when available
"""
import gc

import numpy as np
import torch
from loguru import logger

# Import base class and compiled kernels
from .dire_pytorch import DiRePyTorch, _attraction_forces_compiled  # pylint: disable=cyclic-import

# PyKeOps for efficient force computations; fall back gracefully when it is
# not installed so the module stays importable (sampling-based repulsion is
# used instead).
try:
    from pykeops.torch import LazyTensor
    PYKEOPS_AVAILABLE = True
except ImportError:
    PYKEOPS_AVAILABLE = False
    logger.warning("PyKeOps not available. Install with: pip install pykeops")
class DiRePyTorchMemoryEfficient(DiRePyTorch):
    """
    Memory-optimized PyTorch implementation of DiRe for large-scale datasets.

    This class extends DiRePyTorch with enhanced memory management capabilities,
    making it suitable for processing very large datasets that would otherwise
    cause out-of-memory errors with the standard implementation.

    Key Improvements over DiRePyTorch
    ---------------------------------
    - **FP16 Support**: Uses half-precision by default for 2x memory reduction
    - **Dynamic Chunking**: Automatically adjusts chunk sizes based on available memory
    - **Aggressive Cleanup**: More frequent garbage collection and cache clearing
    - **PyKeOps Integration**: Optional LazyTensors for memory-efficient exact repulsion
    - **Memory Monitoring**: Real-time memory usage tracking and warnings
    - **Point-wise Processing**: Falls back to point-by-point computation when needed

    Best Use Cases
    --------------
    - Datasets with >100K points
    - High-dimensional data (>500 features)
    - Memory-constrained environments
    - Production systems requiring reliable memory usage

    Parameters
    ----------
    *args
        Positional arguments passed to DiRePyTorch parent class.
    use_fp16 : bool, default=True
        Enable FP16 precision for memory efficiency (recommended).
        Provides 2x memory reduction and significant speed improvements.
    use_pykeops_repulsion : bool, default=True
        Use PyKeOps LazyTensors for repulsion when beneficial.
        Automatically disabled if PyKeOps unavailable or dataset too large.
    pykeops_threshold : int, default=50000
        Maximum dataset size for PyKeOps all-pairs computation.
        Above this threshold, random sampling is used instead.
    memory_fraction : float, default=0.25
        Fraction of available memory to use for computations.
        Lower values are more conservative but may be slower.
    **kwargs
        Additional keyword arguments passed to DiRePyTorch parent class.
        Includes: n_components, n_neighbors, init, max_iter_layout, min_dist,
        spread, cutoff, neg_ratio, verbose, random_state, use_exact_repulsion,
        metric (custom distance function for k-NN computation).

    Examples
    --------
    Memory-efficient processing of large dataset::

        from dire_rapids import DiRePyTorchMemoryEfficient
        import numpy as np

        # Large dataset
        X = np.random.randn(500000, 512)

        # Memory-efficient reducer
        reducer = DiRePyTorchMemoryEfficient(
            use_fp16=True,
            memory_fraction=0.3,
            verbose=True
        )
        embedding = reducer.fit_transform(X)

    Custom memory settings::

        reducer = DiRePyTorchMemoryEfficient(
            use_pykeops_repulsion=False,  # Disable PyKeOps
            memory_fraction=0.15,         # Use less memory
            pykeops_threshold=20000       # Lower PyKeOps threshold
        )

    With custom distance metric::

        # L1 metric for k-NN with memory efficiency
        reducer = DiRePyTorchMemoryEfficient(
            metric='(x - y).abs().sum(-1)',
            n_neighbors=32,
            use_fp16=True,
            memory_fraction=0.2
        )
        embedding = reducer.fit_transform(X)
    """

    def __init__(
            self,
            *args,
            use_fp16=True,                # Enable FP16 by default for memory efficiency
            use_pykeops_repulsion=True,   # Use PyKeOps for repulsion when possible
            pykeops_threshold=50000,      # Max points for PyKeOps all-pairs
            memory_fraction=0.25,
            **kwargs
    ):
        """
        Initialize memory-efficient DiRe reducer.

        Parameters
        ----------
        *args
            Positional arguments passed to DiRePyTorch parent class.
        use_fp16 : bool, default=True
            Enable FP16 precision for memory efficiency. Provides 2x memory
            reduction and significant speed improvements on modern GPUs.
        use_pykeops_repulsion : bool, default=True
            Use PyKeOps LazyTensors for memory-efficient repulsion computation
            when dataset size is below pykeops_threshold.
        pykeops_threshold : int, default=50000
            Maximum dataset size for PyKeOps all-pairs computation.
            Above this threshold, random sampling is used instead.
        memory_fraction : float, default=0.25
            Fraction of available memory to use for computations.
            Lower values are more conservative but may be slower.
        **kwargs
            Additional keyword arguments passed to DiRePyTorch parent class.
            See DiRePyTorch documentation for available parameters including:

            - n_components, n_neighbors, init, max_iter_layout, min_dist, spread
            - cutoff, neg_ratio, verbose, random_state, use_exact_repulsion
            - metric: Custom distance metric for k-NN (str, callable, or None)
        """
        # Call parent constructor
        super().__init__(*args, **kwargs)

        # Additional memory-efficient parameters
        self.use_fp16 = use_fp16
        self.use_pykeops_repulsion = use_pykeops_repulsion
        self.pykeops_threshold = pykeops_threshold
        self.memory_fraction = memory_fraction

        # Log memory-efficient settings
        if self.verbose:
            self.logger.info("Memory-efficient mode enabled")
            if self.use_fp16 and self.device.type == 'cuda':
                self.logger.info("FP16 enabled for k-NN computation")
            if self.use_pykeops_repulsion and PYKEOPS_AVAILABLE:
                self.logger.info(f"PyKeOps repulsion enabled (threshold: {self.pykeops_threshold} points)")

    def _get_available_memory(self):
        """
        Get available system memory in bytes.

        Private method that queries the available memory on the current device
        (GPU or CPU) to inform memory-aware chunk sizing decisions.

        Returns
        -------
        int
            Available memory in bytes.

        Notes
        -----
        Private method, should not be called directly. Used by _compute_optimal_chunk_size().
        For CUDA devices, returns free GPU memory.
        For CPU, returns available system RAM.
        """
        if self.device.type == 'cuda':
            # mem_get_info() returns (free, total); we want the free bytes.
            return torch.cuda.mem_get_info()[0]
        # Lazy import: psutil is only needed on the CPU path.
        import psutil  # pylint: disable=import-outside-toplevel
        return psutil.virtual_memory().available

    def _compute_optimal_chunk_size(self, n_samples, n_features, operation_type="knn", dtype=torch.float32):
        """
        Compute optimal chunk size based on available memory and operation type.

        This private method dynamically calculates the optimal chunk size for
        different operations based on available system memory, data
        characteristics, and the configured memory fraction.

        Parameters
        ----------
        n_samples : int
            Total number of samples in the dataset.
        n_features : int
            Number of features per sample.
        operation_type : {'knn', 'repulsion', 'general'}, default='knn'
            Type of operation to optimize for:

            - 'knn': k-nearest neighbors computation
            - 'repulsion': Repulsion force computation
            - 'general': General tensor operations
        dtype : torch.dtype, default=torch.float32
            Data type for memory calculations.

        Returns
        -------
        int
            Optimal chunk size for the specified operation.

        Notes
        -----
        Private method, should not be called directly. Used by _compute_knn()
        and _compute_forces(). The chunk size is bounded between reasonable
        minimum and maximum values to ensure both memory safety and
        computational efficiency.
        """
        available_memory = self._get_available_memory()
        usable_memory = available_memory * self.memory_fraction

        bytes_per_element = 2 if dtype == torch.float16 else 4

        if operation_type == "knn":
            # For k-NN: chunk_size x n_samples x bytes_per_element (distance matrix chunk)
            max_chunk_size = int(usable_memory / (n_samples * bytes_per_element))
        elif operation_type == "repulsion":
            # For repulsion: chunk_size x n_neg x n_components x bytes_per_element
            n_neg = min(int(self.neg_ratio * self.n_neighbors), n_samples - 1)
            memory_per_sample = n_neg * self.n_components * bytes_per_element
            max_chunk_size = int(usable_memory / memory_per_sample)
        else:  # general
            # Conservative estimate: chunk_size x n_features x bytes_per_element
            memory_per_sample = n_features * bytes_per_element
            max_chunk_size = int(usable_memory / memory_per_sample)

        # Apply reasonable bounds
        min_chunk_size = 100
        max_reasonable_chunk_size = min(20000, n_samples)
        optimal_chunk_size = max(min_chunk_size, min(max_chunk_size, max_reasonable_chunk_size))

        if self.verbose:
            self.logger.debug(f"Memory-aware chunk sizing for {operation_type}: "
                              f"{optimal_chunk_size} (available: {available_memory/1e9:.1f}GB)")

        return optimal_chunk_size

    def _compute_knn(self, X, chunk_size=None, use_fp16=None):
        """
        Compute k-nearest neighbors with enhanced memory management.

        This method overrides the parent implementation with memory-aware chunk
        sizing, automatic FP16 selection, and aggressive memory cleanup.

        Parameters
        ----------
        X : numpy.ndarray
            Input data of shape (n_samples, n_features).
        chunk_size : int, optional
            Size of chunks for processing. If None, automatically computed
            based on available memory.
        use_fp16 : bool, optional
            Use FP16 precision. If None, automatically determined based on
            data size and GPU capabilities.

        Notes
        -----
        Private method, should not be called directly. Used by fit_transform().

        Enhancements over parent method:

        - Automatic FP16 selection for large/high-dimensional datasets
        - Memory-aware chunk size computation
        - More aggressive memory cleanup after processing

        Side Effects
        ------------
        Sets self._knn_indices and self._knn_distances with computed k-NN graph
        (performed by the parent implementation).
        """
        n_samples = X.shape[0]
        n_dims = X.shape[1]

        # Use instance setting if not explicitly provided
        if use_fp16 is None:
            use_fp16 = self.use_fp16

        # Force FP16 for large/high-dimensional datasets on GPU
        if self.device.type == 'cuda' and (n_dims >= 100 or n_samples >= 50000):
            use_fp16 = True
            self.logger.info(f"Forcing FP16 for large dataset ({n_samples} samples, {n_dims}D)")

        # Compute optimal chunk size based on available memory
        if chunk_size is None:
            chunk_size = self._compute_optimal_chunk_size(
                n_samples,
                n_dims,
                operation_type="knn",
                dtype=torch.float16 if use_fp16 else torch.float32
            )

        self.logger.info(f"Memory-efficient k-NN: chunk_size={chunk_size}, FP16={use_fp16}")

        # Call parent method with our settings
        super()._compute_knn(X, chunk_size=chunk_size, use_fp16=use_fp16)

        # Aggressive memory cleanup
        if self.device.type == 'cuda':
            torch.cuda.empty_cache()
        gc.collect()

    def _compute_forces(self, positions, iteration, max_iterations):
        """
        Compute forces with memory-efficient strategies and PyKeOps integration.

        This method overrides the parent force computation with enhanced memory
        management and optional PyKeOps LazyTensors for exact repulsion.

        Parameters
        ----------
        positions : torch.Tensor
            Current positions of points in embedding space, shape (n_samples, n_components).
        iteration : int
            Current iteration number (0-indexed).
        max_iterations : int
            Total number of iterations planned.

        Returns
        -------
        torch.Tensor
            Computed forces of shape (n_samples, n_components).
        """
        # Exact all-pairs repulsion is handled entirely by the parent, which
        # computes both attraction and repulsion itself. Delegating up-front
        # avoids computing attraction forces here only to discard them.
        if self.use_exact_repulsion:
            self.logger.debug("Using exact all-pairs repulsion (memory intensive)")
            return super()._compute_forces(positions, iteration, max_iterations)

        n_samples = positions.shape[0]
        forces = torch.zeros_like(positions)

        # Auto-adjust chunk size based on available memory
        chunk_size = self._compute_optimal_chunk_size(
            n_samples,
            self.n_components,
            operation_type="repulsion",
            dtype=torch.float16 if self.use_fp16 else torch.float32
        )

        # Parameters
        a_val = float(self._a)
        b_val = float(self._b)

        # ============ ATTRACTION FORCES (compiled kernel) ============
        # Cache the k-NN indices as a torch tensor. Rebuild the cache whenever
        # the device changed OR the underlying _knn_indices array was replaced
        # (e.g. after a refit), otherwise stale indices would be reused.
        cache_stale = (
            not hasattr(self, '_knn_indices_torch')
            or self._knn_indices_torch.device != positions.device
            or getattr(self, '_knn_indices_src_id', None) != id(self._knn_indices)
        )
        if cache_stale:
            self._knn_indices_torch = torch.as_tensor(self._knn_indices, dtype=torch.long, device=positions.device)
            self._knn_indices_src_id = id(self._knn_indices)
        knn_indices_torch = self._knn_indices_torch

        forces += _attraction_forces_compiled(positions, knn_indices_torch, a_val, b_val)

        # ============ REPULSION FORCES ============
        use_pykeops = (
            PYKEOPS_AVAILABLE and
            self.use_pykeops_repulsion and
            n_samples < self.pykeops_threshold and
            self.device.type == 'cuda'
        )

        if use_pykeops:
            self.logger.debug("Using PyKeOps LazyTensors for repulsion")
            # Lazy symbolic tensors: the (N, N) interaction is never materialized.
            X_i = LazyTensor(positions[:, None, :].contiguous())  # (N, 1, D)
            X_j = LazyTensor(positions[None, :, :].contiguous())  # (1, N, D)
            diff = X_j - X_i                                      # (N, N, D) lazy
            D_sq = (diff ** 2).sum(-1)                            # (N, N) lazy
            D_sq = D_sq + 1e-10                                   # avoid division by zero on the diagonal
            D_sq_b = D_sq ** b_val
            rep_kernel = -1.0 / (1.0 + a_val * D_sq_b)
            D_ij = D_sq.sqrt()
            # Exponential cutoff suppresses long-range interactions.
            cutoff_scale = (-D_ij / self.cutoff).exp()
            rep_kernel = rep_kernel * cutoff_scale
            force_dir = diff / D_ij
            rep_forces = (rep_kernel * force_dir).sum(1)
            forces += rep_forces
        else:
            self.logger.debug("Using chunked random sampling for repulsion")
            n_neg = min(int(self.neg_ratio * self.n_neighbors), n_samples - 1)
            repulsion_chunk_size = min(chunk_size, n_samples)

            for start_idx in range(0, n_samples, repulsion_chunk_size):
                end_idx = min(start_idx + repulsion_chunk_size, n_samples)
                chunk_size_actual = end_idx - start_idx

                # Draw random negative samples; resample any draw that hit the
                # point itself (self-repulsion would be singular).
                neg_indices = torch.randint(0, n_samples, (chunk_size_actual, n_neg), device=self.device)
                chunk_indices = torch.arange(start_idx, end_idx, device=self.device).unsqueeze(1)
                mask = neg_indices == chunk_indices
                if mask.any():
                    replacement_base = torch.randint(0, n_samples, (chunk_size_actual, n_neg), device=self.device)
                    replacement_mask = replacement_base == chunk_indices
                    while replacement_mask.any():
                        replacement_base[replacement_mask] = torch.randint(0, n_samples,
                                                                           (replacement_mask.sum(),),
                                                                           device=self.device)
                        replacement_mask = replacement_base == chunk_indices
                    neg_indices = torch.where(mask, replacement_base, neg_indices)

                chunk_positions = positions[start_idx:end_idx]
                neg_positions = positions[neg_indices]
                center_positions = chunk_positions.unsqueeze(1)

                diff = neg_positions - center_positions
                dist_sq = (diff * diff).sum(dim=2, keepdim=True) + 1e-10
                inv_dist = torch.rsqrt(dist_sq)
                dist = dist_sq * inv_dist  # == sqrt(dist_sq), reusing rsqrt
                dist_sq_b = dist_sq ** b_val
                rep_coeff = -1.0 / (1.0 + a_val * dist_sq_b)
                cutoff_scale = torch.exp(-dist / self.cutoff)
                rep_coeff = rep_coeff * cutoff_scale
                chunk_repulsion_forces = (rep_coeff * diff * inv_dist).sum(dim=1)
                forces[start_idx:end_idx] += chunk_repulsion_forces

                # Clear intermediate tensors to free memory
                del neg_indices, neg_positions, diff, dist, dist_sq, rep_coeff, cutoff_scale, chunk_repulsion_forces

        forces = torch.clamp(forces, -self.cutoff, self.cutoff)
        return forces

    def _optimize_layout(self, initial_positions):
        """
        Optimize embedding layout with enhanced memory monitoring and management.

        This method overrides the parent optimization loop with real-time memory
        monitoring, more frequent cleanup, and detailed progress reporting.

        Parameters
        ----------
        initial_positions : torch.Tensor
            Initial embedding positions of shape (n_samples, n_components).

        Returns
        -------
        torch.Tensor
            Optimized final positions of shape (n_samples, n_components),
            normalized to zero mean and unit standard deviation.

        Notes
        -----
        Private method, should not be called directly. Used by fit_transform().

        Enhancements over parent method:

        - Real-time GPU memory monitoring and warnings
        - More frequent cache clearing (every 10 iterations)
        - Detailed memory usage reporting
        - Low memory warnings when free GPU memory < 2GB
        """
        positions = initial_positions.clone()

        self.logger.info(f"Memory-efficient optimization for {self._n_samples} points...")

        # Log initial memory usage
        if self.device.type == 'cuda':
            mem_reserved = torch.cuda.memory_reserved() / 1e9
            mem_total = torch.cuda.get_device_properties(0).total_memory / 1e9
            self.logger.info(f"Initial GPU memory: {mem_reserved:.2f}GB used / {mem_total:.1f}GB total")

        for iteration in range(self.max_iter_layout):
            # Monitor memory before computation
            if self.device.type == 'cuda' and iteration % 20 == 0:
                mem_available = torch.cuda.mem_get_info()[0] / 1e9
                mem_reserved = torch.cuda.memory_reserved() / 1e9
                if mem_available < 2.0:
                    self.logger.warning(f"Low GPU memory: {mem_available:.1f}GB free, {mem_reserved:.1f}GB used")

            forces = self._compute_forces(positions, iteration, self.max_iter_layout)

            # Linearly decaying learning rate (simulated annealing style).
            alpha = 1.0 - iteration / self.max_iter_layout
            positions.add_(forces, alpha=alpha)

            # Logging and memory management
            if iteration % 10 == 0:
                if self.verbose and iteration % 20 == 0:
                    force_mag = torch.norm(forces, dim=1).mean().item()
                    if self.device.type == 'cuda':
                        mem_reserved = torch.cuda.memory_reserved() / 1e9
                        mem_available = torch.cuda.mem_get_info()[0] / 1e9
                        self.logger.info(f"Iteration {iteration}/{self.max_iter_layout}, avg force: {force_mag:.6f}, "
                                         f"GPU memory: {mem_reserved:.1f}GB used, {mem_available:.1f}GB free")
                    else:
                        self.logger.info(f"Iteration {iteration}/{self.max_iter_layout}, avg force: {force_mag:.6f}")

                # Only pay the cost of empty_cache() when memory is actually tight.
                if self.device.type == 'cuda':
                    mem_available = torch.cuda.mem_get_info()[0] / 1e9
                    if mem_available < 2.0:
                        torch.cuda.empty_cache()

        # Final normalization to zero mean / unit std. Clamp the std with a
        # tiny epsilon so a degenerate (constant) dimension yields zeros
        # instead of NaN; for any non-degenerate dimension the result is
        # bitwise identical to dividing by the raw std.
        positions -= positions.mean(dim=0)
        positions /= positions.std(dim=0).clamp_min(1e-12)

        return positions