# utils.py
"""
Utility classes and functions for dire-rapids package.
This module provides:
- ReducerConfig: Configuration dataclass for dimensionality reduction algorithms
- ReducerRunner: General-purpose runner for dimensionality reduction benchmarking
- Dataset loading utilities for sklearn, OpenML, CyTOF, DiRe geometric datasets, and local files
"""
import inspect
import os
import re
import time
import gzip
import shutil
import urllib.request
from dataclasses import dataclass
from pathlib import Path
import numpy as np
import pandas as pd
from sklearn import datasets as skds
try:
from scipy import sparse as sp
except ImportError:
sp = None # sklearn normally pulls scipy in; keep soft guard
def _identity_transform(X, y):
    """Default transform: pass (X, y) through unchanged."""
    return X, y
# --------- minimal display helpers (so .visualize renders in Colab) ---------
def _safe_init_plotly_renderer():
    """Pick a sensible default Plotly renderer (Colab vs. notebook) if none is configured."""
try:
import plotly.io as pio # pylint: disable=import-outside-toplevel
if pio.renderers.default in (None, "auto"):
try:
import google.colab # noqa: F401 # pylint: disable=import-outside-toplevel,unused-import
pio.renderers.default = "colab"
except ImportError:
pio.renderers.default = "notebook_connected"
except ImportError:
pass
def _display_obj(obj): # pylint: disable=too-many-return-statements
"""Display an object using appropriate renderer (plotly, matplotlib, IPython)."""
if obj is None:
return False
if isinstance(obj, (list, tuple)):
shown = False
for it in obj:
shown = _display_obj(it) or shown
return shown
# Plotly
try:
import plotly.graph_objects as go # pylint: disable=import-outside-toplevel
if isinstance(obj, go.Figure):
_safe_init_plotly_renderer()
obj.show()
return True
except (ImportError, AttributeError):
pass
# Matplotlib
try:
import matplotlib.pyplot as plt # pylint: disable=import-outside-toplevel
from matplotlib.figure import Figure # pylint: disable=import-outside-toplevel
from matplotlib.axes import Axes # pylint: disable=import-outside-toplevel
if isinstance(obj, (Figure, Axes)):
plt.show()
return True
except (ImportError, AttributeError):
pass
# HTML / str
if isinstance(obj, (str, bytes)):
s = obj.decode("utf-8", "ignore") if isinstance(obj, bytes) else obj
if "<" in s and ">" in s:
try:
from IPython.display import display, HTML # pylint: disable=import-outside-toplevel
display(HTML(s))
except ImportError:
print(s) # Fallback to print if IPython not available
else:
print(s)
return True
try:
try:
from IPython.display import display # pylint: disable=import-outside-toplevel
display(obj)
except ImportError:
print(obj) # Fallback to print if IPython not available
return True
except (ImportError, AttributeError, TypeError):
return False
# --------- sklearn resolution ---------
_SKLEARN_ALIASES = {
# loaders
"iris": "load_iris",
"digits": "load_digits",
"wine": "load_wine",
"breast_cancer": "load_breast_cancer",
"diabetes": "load_diabetes",
"linnerud": "load_linnerud",
# generators
"blobs": "make_blobs",
"classification": "make_classification",
"multilabel_classification": "make_multilabel_classification",
"moons": "make_moons",
"circles": "make_circles",
"s_curve": "make_s_curve",
"swiss_roll": "make_swiss_roll",
"gaussian_quantiles": "make_gaussian_quantiles",
"low_rank_matrix": "make_low_rank_matrix",
"spd_matrix": "make_spd_matrix",
"sparse_spd_matrix": "make_sparse_spd_matrix",
}
def _normalize_key(s):
    """Lower-case a selector name and collapse non-alphanumeric runs to underscores."""
    return re.sub(r"[^a-z0-9_]+", "_", s.strip().lower())
def _resolve_sklearn_function(name):
    """Resolve a dataset name to a (function_name, callable) pair from sklearn.datasets."""
    n = _normalize_key(name)
if n.startswith(("load_", "fetch_", "make_")):
fn = getattr(skds, n, None)
if callable(fn):
return n, fn
alias = _SKLEARN_ALIASES.get(n)
if alias and callable(getattr(skds, alias, None)):
return alias, getattr(skds, alias)
for pref in ("load_", "fetch_", "make_"):
cand = pref + n
fn = getattr(skds, cand, None)
if callable(fn):
return cand, fn
candidates = [
(attr, getattr(skds, attr))
for attr in dir(skds)
if attr.lower().endswith(n) and callable(getattr(skds, attr))
]
if len(candidates) == 1:
return candidates[0]
if candidates:
names = ", ".join(a for a, _ in candidates[:6])
raise ValueError(f"Ambiguous sklearn dataset '{name}'. Candidates: {names} ...")
all_names = ", ".join(a for a in dir(skds) if a.startswith(("load_", "fetch_", "make_")))
raise ValueError(f"Unknown sklearn dataset '{name}'. Available include: {all_names}")
def _to_Xy_from_obj(obj):
    """Coerce a sklearn return value (tuple, Bunch-like, or array) into an (X, y) pair."""
    if isinstance(obj, (tuple, list)) and len(obj) >= 1:
X = obj[0]
y = obj[1] if len(obj) > 1 else None
return _coerce_Xy(X, y)
if hasattr(obj, "get"):
data = obj.get("data", None)
target = obj.get("target", None)
images = obj.get("images", None)
if data is None and images is not None:
imgs = np.asarray(images)
data = imgs.reshape(len(imgs), -1)
return _coerce_Xy(data, target)
if hasattr(obj, "shape"):
return _coerce_Xy(obj, None)
raise ValueError("Unsupported sklearn return type; cannot coerce to (X, y).")
def _coerce_Xy(X, y):
    """Densify X to float32 and map string/object labels to integer codes."""
    if isinstance(X, list) and X and isinstance(X[0], str):
raise TypeError("Loaded dataset contains text data; vectorize first.")
if sp is not None and sp.issparse(X):
X = X.toarray()
X = np.asarray(X, dtype=np.float32)
if y is None:
return X, None
y = np.asarray(y)
if y.dtype.kind in {"U", "S", "O"}:
uniq = {v: i for i, v in enumerate(np.unique(y))}
y = np.array([uniq[v] for v in y], dtype=np.int32)
return X, y
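# Minimal sketch of the coercion behaviour (illustrative values only):
# features become dense float32, string labels become integer codes.
def _example_coerce_Xy():
    X, y = _coerce_Xy([[1, 2], [3, 4]], ["a", "b"])
    return X, y  # X: float32 array of shape (2, 2); y: array([0, 1], dtype=int32)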
def _load_sklearn_any(name, **kwargs):
    """Load or generate a sklearn dataset by name, using return_X_y when supported."""
    _, fn = _resolve_sklearn_function(name)
try:
sig = inspect.signature(fn)
if "return_X_y" in sig.parameters:
obj = fn(return_X_y=True, **kwargs)
X, y = _to_Xy_from_obj(obj)
else:
obj = fn(**kwargs)
X, y = _to_Xy_from_obj(obj)
    except TypeError:
        # The loader or generator rejected the supplied kwargs; retry with defaults.
        obj = fn()
X, y = _to_Xy_from_obj(obj)
return X, y
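# Usage sketch (assumes scikit-learn's bundled datasets/generators are available):
def _example_load_sklearn():
    X, y = _load_sklearn_any("iris")                   # loader resolved via alias
    X2, _ = _load_sklearn_any("blobs", n_samples=500)  # generator with kwargs
    return X.shape, y.shape, X2.shape                  # (150, 4), (150,), (500, 2)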
# --------- file loader ---------
def _load_file(path, **kwargs):
    """Load (X, y) from a local .csv, .parquet, .npy, or .npz file."""
    path = str(path)
ext = Path(path).suffix.lower()
if ext == ".csv":
df = pd.read_csv(path)
label_col = kwargs.pop("label_column", None)
if label_col and label_col in df.columns:
y = df[label_col].to_numpy()
X = df.drop(columns=[label_col]).to_numpy(dtype=np.float32)
else:
y = None
X = df.to_numpy(dtype=np.float32)
return X, y
if ext == ".parquet":
df = pd.read_parquet(path)
label_col = kwargs.pop("label_column", None)
if label_col and label_col in df.columns:
y = df[label_col].to_numpy()
X = df.drop(columns=[label_col]).to_numpy(dtype=np.float32)
else:
y = None
X = df.to_numpy(dtype=np.float32)
return X, y
if ext == ".npy":
X = np.load(path, mmap_mode="r")
y = None
labels_path = kwargs.pop("labels_path", None)
if labels_path:
y = np.load(labels_path, mmap_mode="r")
return np.asarray(X, dtype=np.float32), y
if ext == ".npz":
f = np.load(path, mmap_mode="r")
if "X" not in f:
raise ValueError(".npz must contain key 'X' (and optionally 'y').")
X = np.asarray(f["X"], dtype=np.float32)
y = f["y"] if "y" in f else None
return X, y
raise ValueError(f"Unsupported file type '{ext}'. Use .csv, .npy, .npz, or .parquet.")
# --------- DiRe geometric datasets ---------
def rand_point_disk(n_features, n_samples=1):
    """Generate uniformly distributed points in the n-dimensional unit ball (disk)."""
    prepts = np.random.randn(n_samples, n_features)
    prenorms = np.linalg.norm(prepts, axis=1).reshape(-1, 1)
    # For a uniform density in d dimensions the radius must scale as U**(1/d)
    # (sqrt(U) is the d == 2 special case).
    rads = (np.random.rand(n_samples) ** (1.0 / n_features)).reshape(-1, 1)
    pts = prepts * rads / prenorms
    return pts
def rand_point_sphere(n_features, n_samples=1):
"""Generate uniformly distributed points on n-dimensional unit sphere."""
prepts = np.random.randn(n_samples, n_features)
prenorms = np.linalg.norm(prepts, axis=1).reshape(-1, 1)
pts = prepts / prenorms
return pts
class elgen:
"""Ellipsoid generator - transforms sphere points to ellipsoid."""
def __init__(self, a):
a = np.array(a)
themat = np.diag(1 / (a * a))
L = np.linalg.inv(np.linalg.cholesky(themat).T)
self.L = L
def __call__(self, ar):
return (self.L @ ar.T).T
def rand_point_ell(semi_axes, n_features, n_samples=1):
"""Generate uniformly distributed points on n-dimensional ellipsoid with semi-axes."""
spts = rand_point_sphere(n_features, n_samples)
eg = elgen(semi_axes)
return eg(spts)
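# Quick sketch of the three generators; every result has shape (n_samples, n_features).
def _example_geometric_samples():
    disk = rand_point_disk(n_features=3, n_samples=100)      # inside the unit ball
    sphere = rand_point_sphere(n_features=3, n_samples=100)  # on the unit sphere
    ell = rand_point_ell([1.0, 2.0, 3.0], n_features=3, n_samples=100)
    return disk.shape, sphere.shape, ell.shape               # all (100, 3)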
def _load_dire_dataset(name, **kwargs):
"""
Load DiRe geometric datasets.
Supported:
- 'disk_uniform': Uniform in n-dimensional unit disk
- 'sphere_uniform': Uniform on n-dimensional unit sphere
- 'ellipsoid_uniform': Uniform on n-dimensional ellipsoid
Options:
- n_samples (default 1000)
- n_features (default 10)
- semi_axes (for ellipsoid, default [1, 2, ..., n])
- random_state
"""
key = _normalize_key(name)
n_samples = kwargs.pop('n_samples', 1000)
n_features = kwargs.pop('n_features', 10)
random_state = kwargs.pop('random_state', None)
if random_state is not None:
np.random.seed(random_state)
if key == 'disk_uniform':
X = rand_point_disk(n_features, n_samples)
elif key == 'sphere_uniform':
X = rand_point_sphere(n_features, n_samples)
elif key == 'ellipsoid_uniform':
semi_axes = kwargs.pop('semi_axes', None)
if semi_axes is not None:
n_features = len(semi_axes) # Infer n_features from semi_axes
else:
semi_axes = list(range(1, n_features + 1)) # Default semi_axes
X = rand_point_ell(semi_axes, n_features, n_samples)
else:
raise ValueError(
f"Unknown DiRe dataset '{name}'. Options: 'disk_uniform', 'sphere_uniform', 'ellipsoid_uniform'"
)
return X.astype(np.float32), None
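# Usage sketch for the selectors handled above (sample sizes are illustrative):
def _example_dire_datasets():
    X1, _ = _load_dire_dataset("sphere_uniform", n_samples=2000, n_features=5, random_state=0)
    # semi_axes overrides n_features, so this ellipsoid is 4-dimensional.
    X2, _ = _load_dire_dataset("ellipsoid_uniform", n_samples=2000, semi_axes=[1, 2, 3, 4])
    return X1.shape, X2.shape  # (2000, 5) and (2000, 4)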
# --------- cytof scheme (Levine13/32) ---------
_DEF_CACHE = os.path.join(os.path.expanduser("~"), ".cache", "reducer_runner", "cytof")
os.makedirs(_DEF_CACHE, exist_ok=True)
def _download(url, dest, *, overwrite=False):
    """Download url into dest atomically, reusing an existing copy unless overwrite=True."""
    if (not overwrite) and os.path.exists(dest):
return dest
tmp = dest + ".part"
os.makedirs(os.path.dirname(dest), exist_ok=True)
urllib.request.urlretrieve(url, tmp)
os.replace(tmp, dest)
return dest
def _safe_gunzip(path):
    """Decompress a .gz file next to itself (once) and return the decompressed path."""
    if path.endswith(".gz"):
out = path[:-3]
if not os.path.exists(out):
with gzip.open(path, "rb") as f_in, open(out, "wb") as f_out:
shutil.copyfileobj(f_in, f_out)
return out
return path
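# Caching sketch: the file is downloaded once into the cache directory and
# decompressed alongside the archive if it ends in .gz (the URL is hypothetical).
def _example_download_and_unpack():
    dest = os.path.join(_DEF_CACHE, "example.txt.gz")
    local = _download("https://example.com/example.txt.gz", dest)  # hypothetical URL
    return _safe_gunzip(local)  # path of the decompressed file in the cache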
_CYTOF_REGISTRY = {
"levine13": {
"urls": [
"https://raw.githubusercontent.com/lmweber/benchmark-data-Levine-13-dim/master/data/Levine_13dim.fcs",
"https://raw.githubusercontent.com/lmweber/benchmark-data-Levine-13-dim/master/data/Levine_13dim.txt",
],
"label_column": "label",
"drop_columns": ("label", "individual"),
},
"levine32": {
"urls": [
"https://raw.githubusercontent.com/lmweber/benchmark-data-Levine-32-dim/master/data/Levine_32dim.fcs",
],
"label_column": "label",
"drop_columns": ("label", "individual"),
},
}
def _load_cytof(name, **kwargs):
"""
CyTOF loader:
- 'levine13'
- 'levine32'
via built-in URLs/caching
Supports .txt/.tsv/.csv (pandas).
Options:
- url / file / cache_dir
- label_column (for txt/csv/tsv)
- drop_columns
- arcsinh_cofactor (if raw)
"""
key = _normalize_key(name)
spec = _CYTOF_REGISTRY.get(key)
if spec is None:
raise ValueError(f"Unknown cytof dataset '{name}'. Options: {tuple(_CYTOF_REGISTRY.keys())}")
cache_dir = kwargs.pop("cache_dir", _DEF_CACHE)
url = kwargs.pop("url", None)
label_col = kwargs.pop("label_column", spec.get("label_column", "label"))
drop_cols = tuple(kwargs.pop("drop_columns", spec.get("drop_columns", (label_col,))))
drop_unassigned = bool(kwargs.pop("drop_unassigned", False))
arcsinh_cofactor = kwargs.pop("arcsinh_cofactor", None)
local_path = kwargs.pop("file", None)
# Resolve local or download
if local_path is None:
urls = [url] if url else spec.get("urls", [])
if not urls:
raise ValueError(f"cytof:{name} requires 'url' or local 'file' path.")
last_err = None
for u in urls:
try:
fname = os.path.join(cache_dir, os.path.basename(u.split("?")[0]))
local_path = _download(u, fname)
break
except Exception as e: # pylint: disable=broad-exception-caught
last_err = e
local_path = None
if local_path is None:
raise RuntimeError(f"Failed to download cytof:{name}: {last_err}") from last_err
path = _safe_gunzip(local_path)
ext = Path(path).suffix.lower()
# ---------- FCS via flowio ----------
if ext == ".fcs":
try:
import flowio # pylint: disable=import-outside-toplevel
except ImportError as exc:
raise ImportError("flowio required for FCS files. Install with: pip install flowio") from exc
fcs = flowio.FlowData(path)
        data = fcs.as_array()  # 2D event matrix of shape (n_events, n_channels)
# Get channel names from pnn_labels (parameter names)
channel_names = fcs.pnn_labels if fcs.pnn_labels else [f'Ch{i}' for i in range(fcs.channel_count)]
# Create DataFrame from FCS data
df = pd.DataFrame(data, columns=channel_names)
# Drop rows with null labels if requested
if drop_unassigned and label_col in df.columns:
before = len(df)
df = df[df[label_col].notna()].copy()
after = len(df)
print(f"[cytof] dropped {before - after} rows with null labels")
y = df[label_col].to_numpy() if label_col in df.columns else None
drop = [c for c in drop_cols if c in df.columns]
Xdf = df.drop(columns=drop, errors="ignore").select_dtypes(include=[np.number])
X = Xdf.to_numpy(dtype=np.float32, copy=False)
if (arcsinh_cofactor is not None) and arcsinh_cofactor > 0:
X = np.arcsinh(X / float(arcsinh_cofactor)).astype(np.float32)
# map string labels to ints
if y is not None:
y = np.asarray(y)
if y.dtype.kind in {"U", "S", "O"}:
uniq = {v: i for i, v in enumerate(np.unique(y))}
y = np.array([uniq[v] for v in y], dtype=np.int32)
elif y.dtype.kind == "f": # floating point labels
y = y.astype(np.int32)
return X, y
# ---------- TXT/TSV/CSV via pandas ----------
if ext in (".txt", ".tsv", ".csv"):
sep = "\t" if ext in (".txt", ".tsv") else ","
df = pd.read_csv(path, sep=sep)
# Drop rows with null labels if requested
if drop_unassigned and label_col in df.columns:
before = len(df)
df = df[df[label_col].notna()].copy()
after = len(df)
print(f"[cytof] dropped {before - after} rows with null labels")
y = df[label_col].to_numpy() if label_col in df.columns else None
drop = [c for c in drop_cols if c in df.columns]
Xdf = df.drop(columns=drop, errors="ignore").select_dtypes(include=[np.number])
X = Xdf.to_numpy(dtype=np.float32, copy=False)
if (arcsinh_cofactor is not None) and arcsinh_cofactor > 0:
X = np.arcsinh(X / float(arcsinh_cofactor)).astype(np.float32)
# map string labels to ints
if y is not None:
y = np.asarray(y)
if y.dtype.kind in {"U", "S", "O"}:
uniq = {v: i for i, v in enumerate(np.unique(y))}
y = np.array([uniq[v] for v in y], dtype=np.int32)
elif y.dtype.kind == "f": # floating point labels
y = y.astype(np.int32)
return X, y
raise ValueError(f"Unsupported cytof file: {path} (use .fcs, .txt/.tsv, or .csv)")
# --------- ReducerConfig ---------
@dataclass
class ReducerConfig:
"""
Configuration for a dimensionality reduction algorithm.
All fields are mutable and can be changed after creation:
config.visualize = True
config.categorical_labels = False
config.max_points = 20000
"""
name: str
reducer_class: type
reducer_kwargs: dict
visualize: bool = False
categorical_labels: bool = True # False for regression-style labels (swiss_roll, etc.)
max_points: int = 10000 # Max points for visualization (subsamples if larger)
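# Configuration sketch using scikit-learn's TSNE as a stand-in reducer; any
# class or factory exposing fit_transform can be plugged in the same way.
def _example_reducer_config():
    from sklearn.manifold import TSNE  # pylint: disable=import-outside-toplevel
    cfg = ReducerConfig(name="t-SNE", reducer_class=TSNE, reducer_kwargs={"n_components": 2})
    cfg.visualize = True     # fields stay mutable after creation
    cfg.max_points = 20000
    return cfg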
# --------- selector parsing ---------
def _parse_selector(selector):
    """Split a selector into (scheme, name); bare names default to the sklearn scheme."""
    s = selector.strip()
p = Path(s)
if p.exists() or re.search(r"\.(csv|np[yz]|parquet)$", s, re.I):
return "file", s
m = re.match(r"^(?P<scheme>[A-Za-z0-9_]+)[:\.](?P<name>.+)$", s)
if m:
return m.group("scheme").lower(), m.group("name").strip()
return "sklearn", s
# --------- Runner ---------
@dataclass
class ReducerRunner:
"""
General-purpose runner for dimensionality reduction algorithms.
Supports:
- DiRe (create_dire, DiRePyTorch, DiRePyTorchMemoryEfficient, DiReCuVS)
- cuML (UMAP, TSNE)
- scikit-learn (any TransformerMixin-compatible class)
Parameters
----------
config : ReducerConfig
Configuration object containing reducer_class, reducer_kwargs, name, and visualize flag.
"""
config: ReducerConfig
def __post_init__(self):
"""Validate that config is provided."""
if self.config is None:
raise ValueError("Must provide 'config' (ReducerConfig)")
def _get_reducer_info(self):
"""Extract reducer info from config."""
return (
self.config.name,
self.config.reducer_class,
self.config.reducer_kwargs,
self.config.visualize,
self.config.categorical_labels,
self.config.max_points
)
def run(self, dataset, *, dataset_kwargs=None, transform=None):
"""
Run dimensionality reduction on specified dataset.
Parameters
----------
dataset : str
Dataset selector (sklearn:name, openml:name, cytof:name, dire:name, file:path)
dataset_kwargs : dict, optional
Arguments for dataset loader
transform : callable, optional
Custom transform function (X, y) -> (X', y')
Returns
-------
dict
Results containing:
- embedding: reduced data
- labels: data labels
- reducer: fitted reducer instance
- fit_time_sec: time taken for fit_transform
- dataset_info: dataset metadata
"""
# Get reducer configuration
reducer_name, reducer_class, reducer_kwargs, should_visualize, categorical_labels, max_points = self._get_reducer_info()
scheme, name = _parse_selector(dataset)
dataset_kwargs = dataset_kwargs or {}
if scheme == "sklearn":
X, y = _load_sklearn_any(name, **dataset_kwargs)
elif scheme == "file":
X, y = _load_file(name, **dataset_kwargs)
elif scheme == "openml":
from sklearn.datasets import fetch_openml # pylint: disable=import-outside-toplevel
try:
data_id = int(str(name))
ds = fetch_openml(data_id=data_id, return_X_y=True, **dataset_kwargs)
except (ValueError, TypeError):
ds = fetch_openml(name=name, return_X_y=True, **dataset_kwargs)
X, y = _coerce_Xy(ds[0], ds[1])
elif scheme == "cytof":
X, y = _load_cytof(name, **dataset_kwargs)
elif scheme == "dire":
X, y = _load_dire_dataset(name, **dataset_kwargs)
else:
raise ValueError(f"Unsupported scheme '{scheme}'. Use 'sklearn', 'openml', 'cytof', 'dire', 'file'.")
T = transform or _identity_transform
X, y = T(X, y)
# Instantiate reducer (handles both classes and factory functions)
if callable(reducer_class):
reducer = reducer_class(**reducer_kwargs)
else:
raise TypeError(f"reducer_class must be callable, got {type(reducer_class)}")
t0 = time.perf_counter()
embedding = reducer.fit_transform(X)
t1 = time.perf_counter()
# Handle visualization
if should_visualize:
# Only use ReducerRunner's plotly visualization (not the reducer's built-in visualize)
n_dims = embedding.shape[1] if len(embedding.shape) > 1 else 1
if n_dims in (2, 3):
try:
self._visualize_with_plotly(embedding, y, reducer_name, n_dims, categorical_labels, max_points)
except Exception as e: # pylint: disable=broad-exception-caught
print(f"[WARNING] plotly visualization failed: {e}")
return {
"embedding": embedding,
"labels": y,
"reducer": reducer,
"fit_time_sec": float(t1 - t0),
"dataset_info": {
"selector": dataset,
"n_samples": int(X.shape[0]),
"n_features": int(X.shape[1]),
},
}
def _visualize_with_plotly(self, embedding, labels, title, n_dims, categorical_labels=True, max_points=10000):
"""
Create and display plotly visualization for 2D or 3D embeddings.
Uses WebGL rendering (Scattergl) for performance. Automatically subsamples
to max_points if dataset is larger.
"""
try:
import plotly.graph_objects as go # pylint: disable=import-outside-toplevel
except ImportError:
print("[WARNING] plotly not installed. Install with: pip install plotly")
return
_safe_init_plotly_renderer()
n_points = embedding.shape[0]
# Subsample if needed
if n_points > max_points:
rng = np.random.default_rng(42)
subsample_idx = rng.choice(n_points, max_points, replace=False)
embedding_vis = embedding[subsample_idx]
labels_vis = labels[subsample_idx] if labels is not None else None
else:
embedding_vis = embedding
labels_vis = labels
if n_dims == 2:
# Use Scattergl for WebGL acceleration
if labels_vis is not None:
if not categorical_labels:
fig = go.Figure(data=go.Scattergl(
x=embedding_vis[:, 0],
y=embedding_vis[:, 1],
mode='markers',
marker={
"size": 4,
"color": labels_vis,
"colorscale": 'Viridis',
"colorbar": {"title": "Label Value"},
"showscale": True,
"opacity": 0.8
}
))
else:
unique_labels = np.unique(labels_vis)
if len(unique_labels) > 20:
label_to_idx = {lbl: idx for idx, lbl in enumerate(unique_labels)}
colors = np.array([label_to_idx[lbl] for lbl in labels_vis])
fig = go.Figure(data=go.Scattergl(
x=embedding_vis[:, 0],
y=embedding_vis[:, 1],
mode='markers',
marker={
"size": 4,
"color": colors,
"colorscale": 'Viridis',
"showscale": True,
"opacity": 0.8
},
text=[f"Label: {lbl}" for lbl in labels_vis],
hovertemplate='%{text}<extra></extra>'
))
else:
fig = go.Figure()
for label in unique_labels:
mask = labels_vis == label
fig.add_trace(go.Scattergl(
x=embedding_vis[mask, 0],
y=embedding_vis[mask, 1],
mode='markers',
name=str(label),
marker={"size": 4, "opacity": 0.8}
))
else:
fig = go.Figure(data=go.Scattergl(
x=embedding_vis[:, 0],
y=embedding_vis[:, 1],
mode='markers',
marker={"size": 4, "opacity": 0.7}
))
fig.update_layout(
title=f"{title} - 2D Embedding",
xaxis_title="Dimension 1",
yaxis_title="Dimension 2",
width=800,
height=600,
hovermode='closest'
)
elif n_dims == 3:
if labels_vis is not None:
if not categorical_labels:
fig = go.Figure(data=go.Scatter3d(
x=embedding_vis[:, 0],
y=embedding_vis[:, 1],
z=embedding_vis[:, 2],
mode='markers',
marker={
"size": 2,
"color": labels_vis,
"colorscale": 'Viridis',
"colorbar": {"title": "Label Value"},
"showscale": True,
"opacity": 0.8
}
))
else:
unique_labels = np.unique(labels_vis)
if len(unique_labels) > 20:
label_to_idx = {lbl: idx for idx, lbl in enumerate(unique_labels)}
colors = np.array([label_to_idx[lbl] for lbl in labels_vis])
fig = go.Figure(data=go.Scatter3d(
x=embedding_vis[:, 0],
y=embedding_vis[:, 1],
z=embedding_vis[:, 2],
mode='markers',
marker={
"size": 2,
"color": colors,
"colorscale": 'Viridis',
"showscale": True,
"opacity": 0.8
},
text=[f"Label: {lbl}" for lbl in labels_vis],
hovertemplate='%{text}<extra></extra>'
))
else:
fig = go.Figure()
for label in unique_labels:
mask = labels_vis == label
fig.add_trace(go.Scatter3d(
x=embedding_vis[mask, 0],
y=embedding_vis[mask, 1],
z=embedding_vis[mask, 2],
mode='markers',
name=str(label),
marker={"size": 2, "opacity": 0.8}
))
else:
fig = go.Figure(data=go.Scatter3d(
x=embedding_vis[:, 0],
y=embedding_vis[:, 1],
z=embedding_vis[:, 2],
mode='markers',
marker={"size": 2, "opacity": 0.7}
))
fig.update_layout(
title=f"{title} - 3D Embedding",
scene={
"xaxis_title": "Dimension 1",
"yaxis_title": "Dimension 2",
"zaxis_title": "Dimension 3"
},
width=900,
height=700
)
fig.show()
@staticmethod
def available_sklearn():
"""Return available sklearn dataset loaders, fetchers, and generators."""
loads = tuple(a for a in dir(skds) if a.startswith("load_") and callable(getattr(skds, a)))
fetches = tuple(a for a in dir(skds) if a.startswith("fetch_") and callable(getattr(skds, a)))
makes = tuple(a for a in dir(skds) if a.startswith("make_") and callable(getattr(skds, a)))
return {"load": loads, "fetch": fetches, "make": makes}
@staticmethod
def available_cytof():
"""Return available CyTOF datasets."""
return tuple(_CYTOF_REGISTRY.keys())
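# End-to-end sketch: configure a reducer, run it against a selector, and read
# the result dictionary. PCA and the digits dataset are stand-ins here; DiRe or
# cuML classes can be dropped in through the same ReducerConfig.
def _example_runner_end_to_end():
    from sklearn.decomposition import PCA  # pylint: disable=import-outside-toplevel
    cfg = ReducerConfig(name="PCA", reducer_class=PCA, reducer_kwargs={"n_components": 2})
    runner = ReducerRunner(config=cfg)
    result = runner.run("sklearn:digits")
    return result["embedding"].shape, result["fit_time_sec"]  # (1797, 2), seconds as float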