# Source code for graphem_rapids.backends.embedder_pytorch

"""
PyTorch-based implementation of GraphEmbedder with CUDA acceleration.

This module provides the main graph embedding functionality using PyTorch
as the computational backend, with optional CUDA acceleration.
"""

import logging

import numpy as np
import torch
import plotly.graph_objects as go
import scipy.sparse as sp
import scipy.sparse.linalg as spla
from scipy.sparse.csgraph import laplacian
from tqdm import tqdm

from ..utils.memory_management import (
    MemoryManager,
    get_optimal_chunk_size,
    monitor_memory_usage
)

logger = logging.getLogger(__name__)


class GraphEmbedderPyTorch:
    """
    PyTorch-based graph embedder with CUDA acceleration.

    This class provides graph embedding using Laplacian initialization
    followed by force-directed layout optimization, implemented with PyTorch
    for GPU acceleration and memory efficiency.

    Attributes
    ----------
    adjacency : scipy.sparse.csr_matrix
        Sparse adjacency matrix (n_vertices × n_vertices).
    edges : torch.Tensor
        Undirected edge list as an (n_edges, 2) long tensor with i < j.
    n_edges : int
        Number of undirected edges extracted from the adjacency matrix.
    n : int
        Number of vertices in the graph.
    n_components : int
        Number of components (dimensions) in the embedding space.
    device : torch.device
        Computing device (CPU or CUDA).
    positions : np.ndarray
        Current vertex positions as an (n_vertices, n_components) NumPy
        array (property). The underlying tensor lives in ``_positions``.
    """

    def __init__(
        self,
        adjacency,
        n_components=2,
        device=None,
        dtype=torch.float32,
        L_min=1.0,
        k_attr=0.2,
        k_inter=0.5,
        n_neighbors=10,
        sample_size=256,
        batch_size=None,
        memory_efficient=True,
        verbose=True,
        logger_instance=None,
        seed=None
    ):
        """
        Initialize the PyTorch GraphEmbedder.

        Parameters
        ----------
        adjacency : array-like or scipy.sparse matrix
            Adjacency matrix (n_vertices × n_vertices). Can be sparse or
            dense. The graph is treated as undirected and unweighted: any
            nonzero off-diagonal entry (i, j) or (j, i) is an edge.
        n_components : int, default=2
            Number of components (dimensions) in the embedding.
        device : str or torch.device, optional
            Computing device. If None, automatically selects GPU if available.
        dtype : torch.dtype, default=torch.float32
            Data type for computations.
        L_min : float, default=1.0
            Minimum spring length.
        k_attr : float, default=0.2
            Attraction force constant.
        k_inter : float, default=0.5
            Intersection repulsion force constant.
        n_neighbors : int, default=10
            Number of nearest neighbors for intersection detection.
        sample_size : int, default=256
            Number of edge midpoints sampled for the kNN computation
            (capped at the number of edges).
        batch_size : int, optional
            Batch size for processing. If None, it is chosen by
            ``get_optimal_chunk_size``. Must be positive when given.
        memory_efficient : bool, default=True
            Stored on the instance and reported by ``__repr__``.
        verbose : bool, default=True
            Enable verbose logging.
        logger_instance : logging.Logger, optional
            Custom logger instance.
        seed : int, optional
            Random seed for reproducibility. If provided, sets both numpy
            and torch seeds.

        Raises
        ------
        ValueError
            If ``n_components`` is not positive, ``k_attr`` is negative,
            ``batch_size`` is given but not positive, or the adjacency
            matrix is empty or not square.
        """
        # Fail fast on bad scalar parameters, before touching any global
        # state (RNG seeds, logging configuration).
        if n_components <= 0:
            raise ValueError(f"Number of components must be positive, got {n_components}")
        if k_attr < 0:
            raise ValueError(f"Attractive force constant k_attr must be non-negative, got {k_attr}")
        if batch_size is not None and batch_size <= 0:
            # A non-positive batch size would later become a zero/negative
            # range step in the chunked kNN loops.
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        # Set random seeds for reproducibility if provided
        if seed is not None:
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)

        # Setup device
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = torch.device(device)

        # Setup logging
        if logger_instance is not None:
            self.logger = logger_instance
        else:
            self.logger = logger
            if verbose:
                logging.basicConfig(level=logging.INFO)

        # Validate and process adjacency matrix
        adjacency = self._validate_adjacency(adjacency)

        # Store parameters
        self.adjacency = adjacency
        self.n = adjacency.shape[0]  # Infer n_vertices from adjacency
        self.n_components = n_components
        self.dtype = dtype
        self.L_min = L_min
        self.k_attr = k_attr
        self.k_inter = k_inter
        self.n_neighbors = n_neighbors
        self.memory_efficient = memory_efficient
        self.batch_size = batch_size  # None for automatic, or user-defined value
        self.verbose = verbose

        # Extract edges from adjacency matrix
        edges = self._extract_edges_from_adjacency(adjacency)
        self.n_edges = len(edges)

        # Ensure sample_size doesn't exceed number of edges
        self.sample_size = min(sample_size, self.n_edges)

        # Convert edges to tensor
        self.edges = torch.tensor(edges, device=self.device, dtype=torch.long)

        # Memory management - detect backend capability
        self._has_pykeops = self._check_pykeops_availability()
        backend_type = 'pykeops' if self._has_pykeops else 'torch'

        # Use user-defined batch_size if provided, otherwise calculate one
        if self.batch_size is None:
            self.batch_size = get_optimal_chunk_size(
                self.n, self.n_components, backend=backend_type
            )
            if self.verbose:
                self.logger.info("Using automatic batch size: %d", self.batch_size)
        elif self.verbose:
            self.logger.info("Using user-defined batch size: %d", self.batch_size)

        if self.verbose:
            self.logger.info("Initialized GraphEmbedderPyTorch on %s", self.device)
            self.logger.info("Graph: %d vertices, %d edges, %dD",
                             self.n, len(self.edges), self.n_components)
            self.logger.info("KNN backend: %s", backend_type)

        # Compute initial embedding
        self._positions = self._compute_laplacian_embedding()

    def _validate_adjacency(self, adjacency):
        """
        Validate and convert adjacency matrix to scipy sparse format.

        Parameters
        ----------
        adjacency : array-like or scipy.sparse matrix
            Input adjacency matrix

        Returns
        -------
        scipy.sparse.csr_matrix
            Validated adjacency matrix in CSR format

        Raises
        ------
        ValueError
            If the input is not a square 2-D matrix or is empty.
        """
        if sp.issparse(adjacency):
            adjacency = adjacency.tocsr()  # Ensure CSR format
        elif not isinstance(adjacency, np.ndarray):
            # Nested lists and other array-likes
            adjacency = np.asarray(adjacency)

        # Reject non-2-D input explicitly; indexing shape[1] on a 1-D array
        # would otherwise surface as an unrelated IndexError.
        if adjacency.ndim != 2 or adjacency.shape[0] != adjacency.shape[1]:
            raise ValueError(f"Adjacency matrix must be square, got shape {adjacency.shape}")

        if adjacency.shape[0] == 0:
            raise ValueError("Adjacency matrix cannot be empty")

        # Convert to scipy sparse for uniform handling
        if not sp.issparse(adjacency):
            adjacency = sp.csr_matrix(adjacency)

        return adjacency

    def _extract_edges_from_adjacency(self, adjacency):
        """
        Extract edge list from adjacency matrix for undirected graphs.

        Every nonzero off-diagonal entry (i, j) contributes the undirected
        edge (min(i, j), max(i, j)). Entries stored only below the diagonal
        are therefore kept, matching the symmetrization performed by the
        Laplacian initialization. Self-loops are dropped and each edge is
        reported once.

        Parameters
        ----------
        adjacency : scipy.sparse matrix
            Adjacency matrix in sparse format

        Returns
        -------
        np.ndarray
            Edge list as (n_edges, 2) array with i < j, sorted by (i, j).
        """
        rows, cols = adjacency.nonzero()

        off_diagonal = rows != cols
        rows = rows[off_diagonal]
        cols = cols[off_diagonal]

        # Canonical orientation (i < j) so (i, j) and (j, i) collapse to one
        # edge whether or not the input matrix is symmetric.
        edges = np.column_stack([np.minimum(rows, cols), np.maximum(rows, cols)])
        if len(edges) > 0:
            edges = np.unique(edges, axis=0)

        if self.verbose and len(edges) == 0:
            self.logger.warning("No edges found in adjacency matrix")

        return edges

    def _check_pykeops_availability(self):
        """
        Check if PyKeOps is available and functional.

        Returns
        -------
        bool
            True if ``pykeops.torch.LazyTensor`` imports and a tiny symbolic
            squared-distance expression can be built on ``self.device``.
        """
        try:
            from pykeops.torch import LazyTensor  # pylint: disable=import-outside-toplevel
            # Test basic functionality
            test_tensor = torch.randn(2, 3, device=self.device, dtype=self.dtype)
            x_i = LazyTensor(test_tensor[:1, None, :])
            y_j = LazyTensor(test_tensor[None, 1:, :])
            _ = ((x_i - y_j) ** 2).sum(-1)
            return True
        except (ImportError, RuntimeError, AttributeError):
            return False

    def _get_adaptive_chunk_size(self, n_query, n_ref, backend):
        """
        Calculate adaptive chunk size based on available GPU memory and backend.

        On non-CUDA devices, or if the GPU memory query fails, the configured
        ``batch_size`` is returned unchanged.

        Parameters
        ----------
        n_query : int
            Number of query points.
        n_ref : int
            Number of reference points.
        backend : str
            Backend type ('torch' or 'pykeops').

        Returns
        -------
        int
            Chunk size to use (at least 1 on the CUDA path).
        """
        base_chunk_size = self.batch_size

        if self.device.type != 'cuda':
            return base_chunk_size

        try:
            gpu_mem_free, _ = torch.cuda.mem_get_info()

            # Estimate memory for k-NN: chunk_size * n_ref * bytes_per_element
            bytes_per_element = 2 if self.dtype == torch.float16 else 4
            memory_per_chunk = base_chunk_size * n_ref * bytes_per_element

            # Backend-specific memory usage patterns
            if backend == 'pykeops':
                memory_fraction = 0.5
                chunk_multiplier = 2.0
            else:
                memory_fraction = 0.3 if self.dtype == torch.float32 else 0.4
                chunk_multiplier = 1.0

            max_memory = gpu_mem_free * memory_fraction

            if memory_per_chunk > max_memory:
                chunk_size = int(max_memory / (n_ref * bytes_per_element))
                chunk_size = max(1000, chunk_size)
            else:
                chunk_size = int(base_chunk_size * chunk_multiplier)

            # With FP16, we can use larger chunks
            if self.dtype == torch.float16:
                chunk_size = min(chunk_size * 2, 100000)

            # Never exceed the query size, but never drop to 0 either:
            # a zero chunk size would be an invalid range() step.
            chunk_size = max(1, min(chunk_size, n_query))

            if self.verbose:
                self.logger.info(
                    "Adaptive chunk size for %s: %d (GPU memory: %.1fGB, dtype: %s)",
                    backend, chunk_size, gpu_mem_free/1024**3, self.dtype)

            return chunk_size
        except Exception as exc:  # pylint: disable=broad-exception-caught
            # Best effort: fall back to the base chunk size, but leave a trace.
            self.logger.debug("Adaptive chunk sizing unavailable: %s", exc)
            return base_chunk_size

    @property
    def positions(self):
        """Get positions as numpy array for API consistency."""
        return self._positions.detach().cpu().numpy()

    @positions.setter
    def positions(self, value):
        """
        Set positions from a tensor, numpy array, or other array-like.

        Tensors are moved/cast with ``Tensor.to``; everything else is copied
        into a new tensor on ``self.device`` with ``self.dtype``.
        """
        if isinstance(value, torch.Tensor):
            self._positions = value.to(device=self.device, dtype=self.dtype)
        else:
            self._positions = torch.tensor(
                np.asarray(value), dtype=self.dtype, device=self.device
            )

    def _compute_laplacian_embedding(self):
        """
        Compute the Laplacian embedding of the graph using scipy.

        The adjacency matrix is symmetrized and made unweighted, and the
        eigenvectors of its normalized Laplacian are requested from
        ``eigsh`` with ``which='SM'``; the first returned vector is skipped.
        If the eigendecomposition raises, a small random Gaussian
        initialization is used instead.

        Returns
        -------
        torch.Tensor
            Initial positions as (n_vertices, n_components) tensor.
        """
        self.logger.info("Computing Laplacian embedding")

        with MemoryManager(cleanup_on_exit=True):
            # Make symmetric in case it isn't (for undirected graphs)
            A = self.adjacency + self.adjacency.transpose()
            A.data = np.ones_like(A.data)  # Make unweighted for now

            # Convert sparse array to matrix for scipy compatibility
            if hasattr(A, 'toarray'):
                A = sp.csr_matrix(A)

            # Compute normalized Laplacian
            L = laplacian(A, normed=True)

            # One extra eigenvector because the first one is discarded
            k = self.n_components + 1
            try:
                _, eigenvectors = spla.eigsh(L, k, which='SM')
                lap_embedding = eigenvectors[:, 1:k]  # Skip first eigenvector
            except Exception as e:  # pylint: disable=broad-exception-caught
                self.logger.warning("Eigendecomposition failed: %s", e)
                # Fallback to random initialization
                lap_embedding = np.random.randn(self.n, self.n_components) * 0.1

            positions = torch.tensor(
                lap_embedding,
                device=self.device,
                dtype=self.dtype
            )

        self.logger.info("Laplacian embedding computed")
        return positions

    def _locate_knn_midpoints(
        self,
        midpoints,
        k
    ):
        """
        Locate k nearest neighbors for edge midpoints.

        A random subset of at most ``self.sample_size`` midpoints is used as
        queries against all midpoints. ``k + 1`` neighbors are requested and
        the first column is dropped as the self-match.

        Parameters
        ----------
        midpoints : torch.Tensor
            Edge midpoints as (n_edges, n_components) tensor.
        k : int
            Number of nearest neighbors.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            KNN indices of shape (n_sampled, min(k, n_edges - 1)) (zero
            columns when there are no edges) and the sampled edge indices.
        """
        self.logger.info("Computing kNN for midpoints")

        E = midpoints.shape[0]
        sample_size = min(self.sample_size, E)

        with MemoryManager():
            if sample_size < E:
                sampled_indices = torch.randperm(E, device=self.device)[:sample_size]
                sampled_midpoints = midpoints[sampled_indices]
            else:
                sampled_indices = torch.arange(E, device=self.device)
                sampled_midpoints = midpoints

            # k is clamped to the number of midpoints inside the helper
            knn_indices = self._compute_knn_chunked(
                sampled_midpoints, midpoints, k + 1
            )

            # Remove self-neighbors (first column)
            knn_indices = knn_indices[:, 1:]

        self.logger.info("kNN computation completed")
        return knn_indices, sampled_indices

    def _compute_knn_chunked(
        self,
        query_points,
        reference_points,
        k
    ):
        """
        Compute k-nearest neighbors using chunked processing with backend selection.

        PyKeOps is used only when it is available, the dimension is below
        200, there are more than 1000 queries, the device is CUDA and the
        dtype is float32; otherwise ``torch.cdist`` is used. If the PyKeOps
        path raises, the torch path is used as a fallback.

        Parameters
        ----------
        query_points : torch.Tensor
            Query points as (n_query, n_components) tensor.
        reference_points : torch.Tensor
            Reference points as (n_ref, n_components) tensor.
        k : int
            Number of nearest neighbors. Clamped to ``n_ref``.

        Returns
        -------
        torch.Tensor
            KNN indices as (n_query, min(k, n_ref)) tensor.
        """
        n_query = query_points.shape[0]
        n_ref = reference_points.shape[0]
        n_dims = query_points.shape[1]

        # topk/argKmin cannot return more neighbors than reference points
        # exist; clamp instead of crashing on small graphs.
        k = min(k, n_ref)
        if k <= 0:
            return torch.empty((n_query, 0), dtype=torch.long, device=query_points.device)

        # Use PyTorch for high-D, PyKeOps for low-D. PyKeOps is also skipped
        # for small problems due to kernel compilation overhead.
        use_pykeops = (self._has_pykeops and
                       n_dims < 200 and
                       n_query > 1000 and
                       self.device.type == 'cuda' and
                       self.dtype == torch.float32)

        if n_dims >= 200:
            self.logger.info("Using PyTorch for k-NN (high dimension: %dD)", n_dims)
            backend = 'torch'
        elif use_pykeops:
            self.logger.info("Using PyKeOps for k-NN (low dimension, GPU available)")
            backend = 'pykeops'
        else:
            self.logger.info("Using PyTorch for k-NN (n_query=%d, n_dims=%d)", n_query, n_dims)
            backend = 'torch'

        # Adaptive chunking based on backend and memory
        chunk_size = self._get_adaptive_chunk_size(n_query, n_ref, backend)

        try:
            if backend == 'pykeops':
                return self._compute_knn_pykeops(query_points, reference_points, k, chunk_size)
            return self._compute_knn_torch(query_points, reference_points, k, chunk_size)
        except (ImportError, RuntimeError, AttributeError) as e:
            if backend != 'pykeops':
                raise  # keep the original traceback
            if self.verbose:
                self.logger.info("PyKeOps failed, falling back to torch.cdist: %s", str(e))
            chunk_size = self._get_adaptive_chunk_size(n_query, n_ref, 'torch')
            return self._compute_knn_torch(query_points, reference_points, k, chunk_size)

    def _compute_knn_pykeops(
        self,
        query_points,
        reference_points,
        k,
        chunk_size
    ):
        """
        Compute k-nearest neighbors using PyKeOps for memory efficiency.

        Parameters
        ----------
        query_points : torch.Tensor
            Query points as (n_query, n_components) tensor.
        reference_points : torch.Tensor
            Reference points as (n_ref, n_components) tensor.
        k : int
            Number of nearest neighbors.
        chunk_size : int
            Chunk size for processing.

        Returns
        -------
        torch.Tensor
            KNN indices as (n_query, k) tensor.

        Raises
        ------
        ImportError
            If PyKeOps cannot be imported.
        """
        try:
            from pykeops.torch import LazyTensor  # pylint: disable=import-outside-toplevel
        except ImportError as exc:
            raise ImportError("PyKeOps not available") from exc

        n_query = query_points.shape[0]
        n_chunks = (n_query + chunk_size - 1)//chunk_size
        all_knn_indices = []

        for i in range(0, n_query, chunk_size):
            end_idx = min(i + chunk_size, n_query)
            query_chunk = query_points[i:end_idx]

            if n_query > 50000:  # Only log for large datasets
                self.logger.info("Processing PyKeOps chunk %d/%d",
                                 i//chunk_size + 1, n_chunks)

            # Create LazyTensors for symbolic computation
            x_i = LazyTensor(query_chunk[:, None, :])  # (M, 1, D)
            y_j = LazyTensor(reference_points[None, :, :])  # (1, N, D)

            # Compute squared distances symbolically
            D_ij = ((x_i - y_j) ** 2).sum(-1)  # (M, N)

            # Find k nearest neighbors
            knn_indices = D_ij.argKmin(k, dim=1)  # (M, k)
            all_knn_indices.append(knn_indices)

            # Clear GPU memory periodically
            if self.device.type == 'cuda' and i % (chunk_size * 10) == 0:
                torch.cuda.empty_cache()

        return torch.cat(all_knn_indices, dim=0)

    def _compute_knn_torch(
        self,
        query_points,
        reference_points,
        k,
        chunk_size
    ):
        """
        Compute k-nearest neighbors using torch.cdist with chunked processing.

        Parameters
        ----------
        query_points : torch.Tensor
            Query points as (n_query, n_components) tensor.
        reference_points : torch.Tensor
            Reference points as (n_ref, n_components) tensor.
        k : int
            Number of nearest neighbors.
        chunk_size : int
            Chunk size for processing.

        Returns
        -------
        torch.Tensor
            KNN indices as (n_query, k) tensor.
        """
        n_query = query_points.shape[0]

        # Fast path: no chunking needed for small datasets
        if n_query <= chunk_size:
            distances = torch.cdist(query_points, reference_points, p=2)
            _, knn_indices = torch.topk(distances, k, dim=1, largest=False)
            return knn_indices

        # Chunked processing for large datasets
        n_chunks = (n_query + chunk_size - 1)//chunk_size
        all_knn_indices = []

        for i in range(0, n_query, chunk_size):
            end_idx = min(i + chunk_size, n_query)
            query_chunk = query_points[i:end_idx]

            if n_query > 50000:  # Only log for large datasets
                self.logger.info("Processing torch chunk %d/%d",
                                 i//chunk_size + 1, n_chunks)

            # Compute distances for this chunk
            distances = torch.cdist(query_chunk, reference_points, p=2)

            # Find k nearest neighbors
            _, knn_indices = torch.topk(distances, k, dim=1, largest=False)
            all_knn_indices.append(knn_indices)

            # Drop the (chunk, n_ref) distance matrix before the next chunk
            del distances

            # Clear GPU memory periodically (only for very large datasets)
            if self.device.type == 'cuda' and n_query > 10000 and i % (chunk_size * 10) == 0:
                torch.cuda.empty_cache()

        return torch.cat(all_knn_indices, dim=0)

    @monitor_memory_usage
    def _compute_spring_forces(
        self,
        positions,
        edges
    ):
        """
        Compute spring forces between connected vertices.

        For each edge (a, b) the force added to ``a`` is
        ``-k_attr * (dist - L_min) * unit(a -> b)`` and the opposite force
        is added to ``b``.

        Parameters
        ----------
        positions : torch.Tensor
            Current vertex positions.
        edges : torch.Tensor
            Edge list.

        Returns
        -------
        torch.Tensor
            Spring forces for each vertex.
        """
        with MemoryManager():
            # Get edge endpoints
            p1 = positions[edges[:, 0]]
            p2 = positions[edges[:, 1]]

            # Compute edge vectors and distances (epsilon avoids division by 0)
            diff = p2 - p1
            dist = torch.norm(diff, dim=1, keepdim=True) + 1e-6

            # NOTE(review): with this sign, an edge longer than L_min pushes
            # its endpoints apart rather than pulling them together. Positions
            # are re-normalized after every step, so this may be intentional.
            # TODO confirm.
            force_magnitude = -self.k_attr * (dist - self.L_min)

            # Compute force vectors
            edge_forces = force_magnitude * (diff / dist)

            # Accumulate forces on vertices
            forces = torch.zeros_like(positions)
            forces.index_add_(0, edges[:, 0], edge_forces)
            forces.index_add_(0, edges[:, 1], -edge_forces)

        return forces

    @monitor_memory_usage
    def _compute_intersection_forces(
        self,
        positions,
        edges,
        knn_indices,
        sampled_indices
    ):
        """
        Compute intersection repulsion forces between nearby edge pairs.

        Candidate pairs come from the kNN result; pairs sharing a vertex are
        skipped, and only pairs that pass ``_check_line_intersections``
        contribute. Each endpoint of an intersecting pair is pushed away
        from the mean of the four endpoints.

        Parameters
        ----------
        positions : torch.Tensor
            Current vertex positions.
        edges : torch.Tensor
            Edge list.
        knn_indices : torch.Tensor
            KNN indices for edge midpoints.
        sampled_indices : torch.Tensor
            Indices of sampled edges.

        Returns
        -------
        torch.Tensor
            Intersection forces for each vertex.
        """
        # No candidates at all (no edges, or no neighbors besides self)
        if knn_indices.numel() == 0:
            return torch.zeros_like(positions)

        with MemoryManager():
            # Generate edge pairs from KNN results
            _, n_neighbors = knn_indices.shape
            candidate_i = sampled_indices.unsqueeze(1).expand(-1, n_neighbors).flatten()
            candidate_j = knn_indices.flatten()

            # Filter valid pairs (i < j); this also discards self-pairs
            valid_mask = candidate_i < candidate_j
            if not valid_mask.any():
                return torch.zeros_like(positions)

            # Get valid edge pairs
            edges_i = edges[candidate_i[valid_mask]]
            edges_j = edges[candidate_j[valid_mask]]

            # Check for shared vertices (skip connected edges)
            share_mask = (
                (edges_i[:, 0] == edges_j[:, 0]) |
                (edges_i[:, 0] == edges_j[:, 1]) |
                (edges_i[:, 1] == edges_j[:, 0]) |
                (edges_i[:, 1] == edges_j[:, 1])
            )

            interaction_mask = ~share_mask
            if not interaction_mask.any():
                return torch.zeros_like(positions)

            # Filter to interacting pairs
            edges_i = edges_i[interaction_mask]
            edges_j = edges_j[interaction_mask]

            # Get edge endpoints
            p1 = positions[edges_i[:, 0]]
            p2 = positions[edges_i[:, 1]]
            q1 = positions[edges_j[:, 0]]
            q2 = positions[edges_j[:, 1]]

            # Check for line segment intersections using orientation test
            intersect_mask = self._check_line_intersections(p1, p2, q1, q2)
            if not intersect_mask.any():
                return torch.zeros_like(positions)

            # Filter to actually intersecting edges
            edges_i = edges_i[intersect_mask]
            edges_j = edges_j[intersect_mask]
            p1 = p1[intersect_mask]
            p2 = p2[intersect_mask]
            q1 = q1[intersect_mask]
            q2 = q2[intersect_mask]

            # Mean of the four endpoints, used as the repulsion center
            inter_midpoints = (p1 + p2 + q1 + q2) / 4.0

            # Compute repulsion forces
            forces = torch.zeros_like(positions)
            endpoint_groups = (
                (p1, edges_i[:, 0]),
                (p2, edges_i[:, 1]),
                (q1, edges_j[:, 0]),
                (q2, edges_j[:, 1]),
            )
            for vertex_pos, edge_vertices in endpoint_groups:
                diff = vertex_pos - inter_midpoints
                dist = torch.norm(diff, dim=1, keepdim=True) + 1e-6
                repulsion = self.k_inter * diff / (dist ** 2)
                forces.index_add_(0, edge_vertices, repulsion)

        return forces

    def _check_line_intersections(
        self,
        p1,
        p2,
        q1,
        q2
    ):
        """
        Check if line segments (p1,p2) and (q1,q2) intersect.

        Only the first two coordinates of each point are used, so for
        embeddings with more than two components the test is performed on
        the projection onto the first two axes. Touching or collinear
        segments (a zero orientation product) are not reported.

        Parameters
        ----------
        p1, p2 : torch.Tensor
            Endpoints of first line segment.
        q1, q2 : torch.Tensor
            Endpoints of second line segment.

        Returns
        -------
        torch.Tensor
            Boolean mask indicating which pairs intersect.
        """
        def orientation(a, b, c):
            """Compute orientation of ordered triplet (a, b, c)."""
            return (b[..., 0] - a[..., 0]) * (c[..., 1] - a[..., 1]) - \
                   (b[..., 1] - a[..., 1]) * (c[..., 0] - a[..., 0])

        # Compute orientations
        o1 = orientation(p1, p2, q1)
        o2 = orientation(p1, p2, q2)
        o3 = orientation(q1, q2, p1)
        o4 = orientation(q1, q2, p2)

        # Segments cross when each one strictly separates the other's ends
        return (o1 * o2 < 0) & (o3 * o4 < 0)

    def update_positions(self):
        """
        Update vertex positions based on computed forces.

        Performs one step: spring forces plus intersection forces are added
        to the positions, which are then centered and scaled to unit
        standard deviation per component.
        """
        self.logger.info("Updating vertex positions")

        with MemoryManager():
            # Compute spring forces
            spring_forces = self._compute_spring_forces(self._positions, self.edges)

            # Compute edge midpoints
            midpoints = (self._positions[self.edges[:, 0]] +
                         self._positions[self.edges[:, 1]]) / 2.0

            # Find nearest neighbors for intersection detection
            knn_indices, sampled_indices = self._locate_knn_midpoints(
                midpoints, self.n_neighbors
            )

            # Compute intersection forces
            inter_forces = self._compute_intersection_forces(
                self._positions, self.edges, knn_indices, sampled_indices
            )

            # Combine forces and update positions
            new_positions = self._positions + spring_forces + inter_forces

            # Normalize positions (center and scale)
            new_positions = new_positions - torch.mean(new_positions, dim=0, keepdim=True)
            if new_positions.shape[0] > 1:
                # The standard deviation of a single sample is NaN, so only
                # rescale when there are at least two vertices.
                std = torch.std(new_positions, dim=0, keepdim=True) + 1e-6
                new_positions = new_positions / std
            self._positions = new_positions

        self.logger.info("Position update completed")

    def run_layout(self, num_iterations=100):
        """
        Run the force-directed layout algorithm.

        Parameters
        ----------
        num_iterations : int, default=100
            Number of iterations to run.

        Returns
        -------
        np.ndarray
            Final vertex positions as (n_vertices, n_components) array.
        """
        self.logger.info("Running layout for %d iterations", num_iterations)

        with MemoryManager(cleanup_on_exit=True):
            for iteration in tqdm(range(num_iterations), desc="Layout iterations"):
                self.update_positions()

                # Optional: log progress
                if self.verbose and (iteration + 1) % 10 == 0:
                    self.logger.info("Completed iteration %d/%d", iteration + 1, num_iterations)

        self.logger.info("Layout computation completed")
        return self.positions

    def get_positions(self):
        """
        Get vertex positions as numpy array.

        Returns
        -------
        np.ndarray
            Vertex positions.
        """
        return self.positions  # Uses property which already returns numpy array

    def display_layout(
        self,
        edge_width=1,
        node_size=3,
        node_colors=None
    ):
        """
        Display the graph embedding using Plotly.

        Parameters
        ----------
        edge_width : float, default=1
            Width of the edges.
        node_size : float, default=3
            Size of the nodes.
        node_colors : array-like, optional
            Colors for each vertex.

        Raises
        ------
        ValueError
            If ``n_components`` is neither 2 nor 3.
        """
        self.logger.info("Displaying layout")

        if self.n_components == 2:
            self._display_layout_2d(edge_width, node_size, node_colors)
        elif self.n_components == 3:
            self._display_layout_3d(edge_width, node_size, node_colors)
        else:
            raise ValueError("Can only display 2D or 3D layouts")

    @staticmethod
    def _edge_line_coordinates(pos, edges_np):
        """Return one coordinate list per axis, with None separating edges."""
        coords = [[] for _ in range(pos.shape[1])]
        for i, j in edges_np:
            for axis, axis_coords in enumerate(coords):
                # None breaks the line so consecutive edges are not joined
                axis_coords.extend([pos[i, axis], pos[j, axis], None])
        return coords

    def _display_layout_2d(
        self,
        edge_width,
        node_size,
        node_colors
    ):
        """Display 2D graph embedding."""
        pos = self.get_positions()
        edges_np = self.edges.cpu().numpy()

        # Create edge traces
        x_edges, y_edges = self._edge_line_coordinates(pos[:, :2], edges_np)

        edge_trace = go.Scatter(
            x=x_edges, y=y_edges,
            mode='lines',
            line={'color': 'gray', 'width': edge_width},
            hoverinfo='none'
        )

        # Create node trace
        node_trace = go.Scatter(
            x=pos[:, 0], y=pos[:, 1],
            mode='markers',
            marker={
                'color': node_colors if node_colors is not None else 'red',
                'colorscale': 'Bluered',
                'size': node_size,
                'colorbar': {'title': 'Node Label'},
                'showscale': node_colors is not None
            },
            hoverinfo='none'
        )

        # Create figure
        fig = go.Figure(data=[edge_trace, node_trace])
        fig.update_layout(
            title="2D Graph Embedding (PyTorch)",
            xaxis={'title': 'X', 'showgrid': False, 'zeroline': False},
            yaxis={'title': 'Y', 'showgrid': False, 'zeroline': False},
            showlegend=False,
            width=800,
            height=800
        )
        fig.show()

    def _display_layout_3d(
        self,
        edge_width,
        node_size,
        node_colors
    ):
        """Display 3D graph embedding."""
        pos = self.get_positions()
        edges_np = self.edges.cpu().numpy()

        # Create edge traces
        x_edges, y_edges, z_edges = self._edge_line_coordinates(pos[:, :3], edges_np)

        edge_trace = go.Scatter3d(
            x=x_edges, y=y_edges, z=z_edges,
            mode='lines',
            line={'color': 'gray', 'width': edge_width},
            hoverinfo='none'
        )

        # Create node trace
        node_trace = go.Scatter3d(
            x=pos[:, 0], y=pos[:, 1], z=pos[:, 2],
            mode='markers',
            marker={
                'color': node_colors if node_colors is not None else 'red',
                'colorscale': 'Bluered',
                'size': node_size,
                'colorbar': {'title': 'Node Label'},
                'showscale': node_colors is not None
            },
            hoverinfo='none'
        )

        # Create figure
        fig = go.Figure(data=[edge_trace, node_trace])
        fig.update_layout(
            title="3D Graph Embedding (PyTorch)",
            scene={'xaxis': {'title': 'X'},
                   'yaxis': {'title': 'Y'},
                   'zaxis': {'title': 'Z'}},
            showlegend=False,
            width=800,
            height=800
        )
        fig.show()

    def __repr__(self):
        """String representation of the embedder."""
        return (f"GraphEmbedderPyTorch(n_vertices={self.n}, n_components={self.n_components}, "
                f"device={self.device}, memory_efficient={self.memory_efficient})")